From d77f335ff0c3e39252f71f00051a16cfea63b876 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 20 Feb 2021 17:01:27 -0800 Subject: [PATCH] feat: add `LawfulMonad` instance for `ExceptT` --- src/Init/Control/Except.lean | 21 ++++---- src/Init/Control/Lawful.lean | 95 ++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 9 deletions(-) diff --git a/src/Init/Control/Except.lean b/src/Init/Control/Except.lean index 9303cf8797..13fd5ec154 100644 --- a/src/Init/Control/Except.lean +++ b/src/Init/Control/Except.lean @@ -10,36 +10,39 @@ import Init.Control.Basic import Init.Control.Id import Init.Coe -universes u v w u' - namespace Except variable {ε : Type u} -@[inline] protected def pure {α : Type v} (a : α) : Except ε α := +@[inline] protected def pure (a : α) : Except ε α := Except.ok a -@[inline] protected def map {α β : Type v} (f : α → β) : Except ε α → Except ε β +@[inline] protected def map (f : α → β) : Except ε α → Except ε β | Except.error err => Except.error err | Except.ok v => Except.ok <| f v -@[inline] protected def mapError {ε' : Type u} {α : Type v} (f : ε → ε') : Except ε α → Except ε' α +@[simp] theorem map_id : Except.map (ε := ε) (α := α) (β := α) id = id := by + apply funext + intro e + simp [Except.map]; cases e <;> rfl + +@[inline] protected def mapError (f : ε → ε') : Except ε α → Except ε' α | Except.error err => Except.error <| f err | Except.ok v => Except.ok v -@[inline] protected def bind {α β : Type v} (ma : Except ε α) (f : α → Except ε β) : Except ε β := +@[inline] protected def bind (ma : Except ε α) (f : α → Except ε β) : Except ε β := match ma with | Except.error err => Except.error err | Except.ok v => f v -@[inline] protected def toBool {α : Type v} : Except ε α → Bool +@[inline] protected def toBool : Except ε α → Bool | Except.ok _ => true | Except.error _ => false -@[inline] protected def toOption {α : Type v} : Except ε α → Option α +@[inline] protected def toOption : Except ε α → Option α | Except.ok a => some a | Except.error _ => none -@[inline] protected def tryCatch {α : Type u} (ma : Except ε α) (handle : ε → Except ε α) : Except ε α := +@[inline] protected def tryCatch (ma : Except ε α) (handle : ε → Except ε α) : Except ε α := match ma with | Except.ok a => Except.ok a | Except.error e => handle e diff --git a/src/Init/Control/Lawful.lean b/src/Init/Control/Lawful.lean index 5727347b3a..a0ec869a98 100644 --- a/src/Init/Control/Lawful.lean +++ b/src/Init/Control/Lawful.lean @@ -5,6 +5,7 @@ Authors: Sebastian Ullrich, Leonardo de Moura -/ prelude import Init.SimpLemmas +import Init.Control.Except open Function @@ -24,6 +25,9 @@ class LawfulApplicative (f : Type u → Type v) [Applicative f] extends LawfulFu map_pure (g : α → β) (x : α) : g <$> (pure x : f α) = pure (g x) seq_pure (g : f (α → β)) (x : α) : g <*> pure x = (fun h : α → β => h x) <$> g seq_assoc (x : f α) (g : f (α → β)) (h : f (β → γ)) : h <*> (g <*> x) = (@comp α β γ <$> h) <*> g <*> x + comp_map g h x := by + repeat rw [← pure_seq] + simp [seq_assoc, map_pure, seq_pure] export LawfulApplicative (seqLeft_eq seqRight_eq pure_seq map_pure seq_pure seq_assoc) @@ -37,6 +41,15 @@ class LawfulMonad (m : Type u → Type v) [Monad m] extends LawfulApplicative m bind_map (f : m (α → (β : Type u))) (x : m α) : f >>= (. <$> x) = f <*> x pure_bind (x : α) (f : α → m β) : pure x >>= f = f x bind_assoc (x : m α) (f : α → m β) (g : β → m γ) : x >>= f >>= g = x >>= fun x => f x >>= g + map_pure g x := by rw [← bind_pure_comp, pure_bind] + seq_pure g x := by rw [← bind_map]; simp [map_pure, bind_pure_comp] + seq_assoc x g h := by + -- TODO: support for applying `symm` at `simp` arguments + let bind_pure_comp_symm {α β : Type u} (f : α → β) (x : m α) : f <$> x = x >>= pure ∘ f := by + rw [bind_pure_comp] + let bind_map_symm {α β : Type u} (f : m (α → (β : Type u))) (x : m α) : f <*> x = f >>= (. <$> x) := by + rw [bind_map] + simp[bind_pure_comp_symm, bind_map_symm, bind_assoc, pure_bind] export LawfulMonad (bind_pure_comp bind_map pure_bind bind_assoc) attribute [simp] pure_bind bind_assoc @@ -53,3 +66,85 @@ theorem bind_congr [Bind m] {x : m α} {f g : α → m β} (h : ∀ a, f a = g a theorem map_congr [Functor m] {x : m α} {f g : α → β} (h : ∀ a, f a = g a) : (f <$> x : m β) = g <$> x := by simp [funext h] + +/- Id -/ + +namespace Id + +@[simp] theorem map_eq (x : Id α) (f : α → β) : f <$> x = f x := rfl +@[simp] theorem bind_eq (x : Id α) (f : α → id β) : x >>= f = f x := rfl +@[simp] theorem pure_eq (a : α) : (pure a : Id α) = a := rfl + +instance : LawfulMonad Id := by + refine! { .. } <;> intros <;> rfl + +end Id + +/- ExceptT -/ + +namespace ExceptT + +theorem ext [Monad m] {x y : ExceptT ε m α} (h : x.run = y.run) : x = y := by + simp [run] at h + assumption + +@[simp] theorem run_pure [Monad m] : run (pure x : ExceptT ε m α) = pure (Except.ok x) := rfl +@[simp] theorem run_lift [Monad m] : run (ExceptT.lift x : ExceptT ε m α) = Except.ok <$> x := rfl +@[simp] theorem run_throw [Monad m] : run (throw e : ExceptT ε m β) = pure (Except.error e) := rfl +@[simp] theorem run_bind [Monad m] (x : ExceptT ε m α) + : run (x >>= f : ExceptT ε m β) + = + run x >>= fun + | Except.ok x => run (f x) + | Except.error e => pure (Except.error e) := + rfl + +@[simp] theorem lift_pure [Monad m] [LawfulMonad m] (a : α) : ExceptT.lift (pure a) = (pure a : ExceptT ε m α) := by + simp [ExceptT.lift, pure, ExceptT.pure] + +@[simp] theorem run_map [Monad m] [LawfulMonad m] (f : α → β) (x : ExceptT ε m α) + : (f <$> x).run = Except.map f <$> x.run := by + rw [← bind_pure_comp (m := m)] + simp [Functor.map, ExceptT.map] + apply bind_congr + intro a; cases a <;> simp [Except.map] + +protected theorem seq_eq {α β ε : Type u} [Monad m] (mf : ExceptT ε m (α → β)) (x : ExceptT ε m α) : mf <*> x = mf >>= fun f => f <$> x := + rfl + +protected theorem bind_pure_comp [Monad m] [LawfulMonad m] (f : α → β) (x : ExceptT ε m α) : x >>= pure ∘ f = f <$> x := by + intros; rfl + +protected theorem seqLeft_eq {α β ε : Type u} {m : Type u → Type v} [Monad m] [LawfulMonad m] (x : ExceptT ε m α) (y : ExceptT ε m β) : x <* y = const β <$> x <*> y := by + show (x >>= fun a => y >>= fun _ => pure a) = (const (α := α) β <$> x) >>= fun f => f <$> y + rw [← ExceptT.bind_pure_comp] + apply ext + simp + apply bind_congr + intro a + cases a with + | error => simp + | ok => + simp; rw [← bind_pure_comp]; apply bind_congr; intro b; + cases b <;> simp [comp, Except.map, const] + +protected theorem seqRight_eq [Monad m] [LawfulMonad m] (x : ExceptT ε m α) (y : ExceptT ε m β) : x *> y = const α id <$> x <*> y := by + show (x >>= fun _ => y) = (const α id <$> x) >>= fun f => f <$> y + rw [← ExceptT.bind_pure_comp] + apply ext + simp + apply bind_congr + intro a; cases a <;> simp + +instance [Monad m] [LawfulMonad m] : LawfulMonad (ExceptT ε m) where + id_map := by intros; apply ext; simp + map_const := by intros; rfl + seqLeft_eq := ExceptT.seqLeft_eq + seqRight_eq := ExceptT.seqRight_eq + pure_seq := by intros; apply ext; simp [ExceptT.seq_eq] + bind_pure_comp := ExceptT.bind_pure_comp + bind_map := by intros; rfl + pure_bind := by intros; apply ext; simp + bind_assoc := by intros; apply ext; simp; apply bind_congr; intro a; cases a <;> simp + +end ExceptT