feat: make sure StateRefT can be used with any base EIO Exception

This commit is contained in:
Leonardo de Moura 2020-08-20 13:11:12 -07:00
parent d36ccb166c
commit 968457ac1c
2 changed files with 19 additions and 19 deletions

View file

@ -11,13 +11,13 @@ import Init.Control.State
def StateRefT (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (IO.Ref σ) m α
@[inline] def StateRefT.run {σ : Type} {m : Type → Type} [Monad m] [MonadIO m] {α : Type} (x : StateRefT σ m α) (s : σ) : m (α × σ) := do
@[inline] def StateRefT.run {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (EIO Empty) m] {α : Type} (x : StateRefT σ m α) (s : σ) : m (α × σ) := do
ref ← IO.mkRef s;
a ← x ref;
s ← ref.get;
pure (a, s)
@[inline] def StateRefT.run' {σ : Type} {m : Type → Type} [Monad m] [MonadIO m] {α : Type} (x : StateRefT σ m α) (s : σ) : m α := do
@[inline] def StateRefT.run' {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (EIO Empty) m] {α : Type} (x : StateRefT σ m α) (s : σ) : m α := do
(a, _) ← x.run s;
pure a
@ -28,22 +28,22 @@ variables {σ : Type} {m : Type → Type} {α : Type}
fun _ => x
instance [Monad m] : Monad (StateRefT σ m) := inferInstanceAs (Monad (ReaderT _ _))
instance [Monad m] [MonadIO m] : MonadIO (StateRefT σ m) := inferInstanceAs (MonadIO (ReaderT _ _))
instance : HasMonadLift m (StateRefT σ m) := ⟨fun _ => StateRefT.lift⟩
instance [Monad m] [MonadIO m] : MonadIO (StateRefT σ m) := inferInstanceAs (MonadIO (ReaderT _ _))
instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT σ m) (StateRefT σ m') :=
inferInstanceAs (MonadFunctor m m' (ReaderT _ _) (ReaderT _ _))
@[inline] protected def get [Monad m] [MonadIO m] : StateRefT σ m σ :=
@[inline] protected def get [Monad m] [HasMonadLift (EIO Empty) m] : StateRefT σ m σ :=
fun ref => ref.get
@[inline] protected def set [Monad m] [MonadIO m] (s : σ) : StateRefT σ m PUnit :=
@[inline] protected def set [Monad m] [HasMonadLift (EIO Empty) m] (s : σ) : StateRefT σ m PUnit :=
fun ref => ref.set s
@[inline] protected def modifyGet [Monad m] [MonadIO m] (f : σα × σ) : StateRefT σ m α :=
@[inline] protected def modifyGet [Monad m] [HasMonadLift (EIO Empty) m] (f : σα × σ) : StateRefT σ m α :=
fun ref => ref.modifyGet f
instance [Monad m] [MonadIO m] : MonadStateOf σ (StateRefT σ m) :=
instance [Monad m] [HasMonadLift (EIO Empty) m] : MonadStateOf σ (StateRefT σ m) :=
{ get := StateRefT.get,
set := StateRefT.set,
modifyGet := fun α f => StateRefT.modifyGet f }

View file

@ -387,24 +387,24 @@ end Prim
section
@[inline] private def toIO {α} (x : EIO Empty α) : IO α :=
@[inline] private def fromEmptyEIO {ε α} (x : EIO Empty α) : EIO ε α :=
fun s => match x s with
| r@(EStateM.Result.error e _) => Empty.rec _ e
| EStateM.Result.ok a s => EStateM.Result.ok a s
variables {m : Type → Type} [Monad m] [MonadIO m]
instance EIOEmpty.monadLift {ε} : HasMonadLift (EIO Empty) (EIO ε) :=
{ monadLift := fun α => fromEmptyEIO }
@[inline] def liftEIOEmpty {α} (x : EIO Empty α) : m α :=
liftIO $ toIO x
variables {m : Type → Type} [Monad m] [HasMonadLiftT (EIO Empty) m]
@[inline] def mkRef {α : Type} (a : α) : m (Ref α) := liftEIOEmpty $ Prim.mkRef a
@[inline] def Ref.get {α : Type} (r : Ref α) : m α := liftEIOEmpty $ Prim.Ref.get r
@[inline] def Ref.set {α : Type} (r : Ref α) (a : α) : m Unit := liftEIOEmpty $ Prim.Ref.set r a
@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := liftEIOEmpty $ Prim.Ref.swap r a
@[inline] unsafe def Ref.take {α : Type} (r : Ref α) : m α := liftEIOEmpty $ Prim.Ref.take r
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref α) : m Bool := liftEIOEmpty $ Prim.Ref.ptrEq r1 r2
@[inline] def Ref.modify {α : Type} (r : Ref α) (f : αα) : m Unit := liftEIOEmpty $ Prim.Ref.modify r f
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : m β := liftEIOEmpty $ Prim.Ref.modifyGet r f
@[inline] def mkRef {α : Type} (a : α) : m (Ref α) := liftM $ Prim.mkRef a
@[inline] def Ref.get {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.get r
@[inline] def Ref.set {α : Type} (r : Ref α) (a : α) : m Unit := liftM $ Prim.Ref.set r a
@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := liftM $ Prim.Ref.swap r a
@[inline] unsafe def Ref.take {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.take r
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2
@[inline] def Ref.modify {α : Type} (r : Ref α) (f : αα) : m Unit := liftM $ Prim.Ref.modify r f
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f
end