From 968457ac1ca9965fd5427cf308f9de3bdc220a2b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 20 Aug 2020 13:11:12 -0700 Subject: [PATCH] feat: make sure `StateRefT` can be used with any base `EIO Exception` --- src/Init/Control/StateRef.lean | 14 +++++++------- src/Init/System/IO.lean | 24 ++++++++++++------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/Init/Control/StateRef.lean b/src/Init/Control/StateRef.lean index 5a0ee426b3..ed2814dfb6 100644 --- a/src/Init/Control/StateRef.lean +++ b/src/Init/Control/StateRef.lean @@ -11,13 +11,13 @@ import Init.Control.State def StateRefT (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (IO.Ref σ) m α -@[inline] def StateRefT.run {σ : Type} {m : Type → Type} [Monad m] [MonadIO m] {α : Type} (x : StateRefT σ m α) (s : σ) : m (α × σ) := do +@[inline] def StateRefT.run {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (EIO Empty) m] {α : Type} (x : StateRefT σ m α) (s : σ) : m (α × σ) := do ref ← IO.mkRef s; a ← x ref; s ← ref.get; pure (a, s) -@[inline] def StateRefT.run' {σ : Type} {m : Type → Type} [Monad m] [MonadIO m] {α : Type} (x : StateRefT σ m α) (s : σ) : m α := do +@[inline] def StateRefT.run' {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (EIO Empty) m] {α : Type} (x : StateRefT σ m α) (s : σ) : m α := do (a, _) ← x.run s; pure a @@ -28,22 +28,22 @@ variables {σ : Type} {m : Type → Type} {α : Type} fun _ => x instance [Monad m] : Monad (StateRefT σ m) := inferInstanceAs (Monad (ReaderT _ _)) -instance [Monad m] [MonadIO m] : MonadIO (StateRefT σ m) := inferInstanceAs (MonadIO (ReaderT _ _)) instance : HasMonadLift m (StateRefT σ m) := ⟨fun _ => StateRefT.lift⟩ +instance [Monad m] [MonadIO m] : MonadIO (StateRefT σ m) := inferInstanceAs (MonadIO (ReaderT _ _)) instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT σ m) (StateRefT σ m') := inferInstanceAs (MonadFunctor m m' (ReaderT _ _) (ReaderT _ _)) -@[inline] protected def get [Monad m] [MonadIO m] : StateRefT σ m σ := +@[inline] protected def get [Monad m] [HasMonadLift (EIO Empty) m] : StateRefT σ m σ := fun ref => ref.get -@[inline] protected def set [Monad m] [MonadIO m] (s : σ) : StateRefT σ m PUnit := +@[inline] protected def set [Monad m] [HasMonadLift (EIO Empty) m] (s : σ) : StateRefT σ m PUnit := fun ref => ref.set s -@[inline] protected def modifyGet [Monad m] [MonadIO m] (f : σ → α × σ) : StateRefT σ m α := +@[inline] protected def modifyGet [Monad m] [HasMonadLift (EIO Empty) m] (f : σ → α × σ) : StateRefT σ m α := fun ref => ref.modifyGet f -instance [Monad m] [MonadIO m] : MonadStateOf σ (StateRefT σ m) := +instance [Monad m] [HasMonadLift (EIO Empty) m] : MonadStateOf σ (StateRefT σ m) := { get := StateRefT.get, set := StateRefT.set, modifyGet := fun α f => StateRefT.modifyGet f } diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index a359418b71..d90895f639 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -387,24 +387,24 @@ end Prim section -@[inline] private def toIO {α} (x : EIO Empty α) : IO α := +@[inline] private def fromEmptyEIO {ε α} (x : EIO Empty α) : EIO ε α := fun s => match x s with | r@(EStateM.Result.error e _) => Empty.rec _ e | EStateM.Result.ok a s => EStateM.Result.ok a s -variables {m : Type → Type} [Monad m] [MonadIO m] +instance EIOEmpty.monadLift {ε} : HasMonadLift (EIO Empty) (EIO ε) := +{ monadLift := fun α => fromEmptyEIO } -@[inline] def liftEIOEmpty {α} (x : EIO Empty α) : m α := -liftIO $ toIO x +variables {m : Type → Type} [Monad m] [HasMonadLiftT (EIO Empty) m] -@[inline] def mkRef {α : Type} (a : α) : m (Ref α) := liftEIOEmpty $ Prim.mkRef a -@[inline] def Ref.get {α : Type} (r : Ref α) : m α := liftEIOEmpty $ Prim.Ref.get r -@[inline] def Ref.set {α : Type} (r : Ref α) (a : α) : m Unit := liftEIOEmpty $ Prim.Ref.set r a -@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := liftEIOEmpty $ Prim.Ref.swap r a -@[inline] unsafe def Ref.take {α : Type} (r : Ref α) : m α := liftEIOEmpty $ Prim.Ref.take r -@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref α) : m Bool := liftEIOEmpty $ Prim.Ref.ptrEq r1 r2 -@[inline] def Ref.modify {α : Type} (r : Ref α) (f : α → α) : m Unit := liftEIOEmpty $ Prim.Ref.modify r f -@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : m β := liftEIOEmpty $ Prim.Ref.modifyGet r f +@[inline] def mkRef {α : Type} (a : α) : m (Ref α) := liftM $ Prim.mkRef a +@[inline] def Ref.get {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.get r +@[inline] def Ref.set {α : Type} (r : Ref α) (a : α) : m Unit := liftM $ Prim.Ref.set r a +@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := liftM $ Prim.Ref.swap r a +@[inline] unsafe def Ref.take {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.take r +@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2 +@[inline] def Ref.modify {α : Type} (r : Ref α) (f : α → α) : m Unit := liftM $ Prim.Ref.modify r f +@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f end