feat: real ST monad
@Kha: the new `ST` (and `EST`) are escapable like the Haskell ST monad. It makes `StateRefT` much more useful because we can now run it from pure code.
This commit is contained in:
parent
a8f68f6360
commit
77b9445544
8 changed files with 153 additions and 100 deletions
|
|
@ -9,47 +9,47 @@ prelude
|
|||
import Init.System.IO
|
||||
import Init.Control.State
|
||||
|
||||
def StateRefT (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (IO.Ref σ) m α
|
||||
def StateRefT' (ω : Type) (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (ST.Ref ω σ) m α
|
||||
abbrev StateRefT {ω : Type} (σ : Type) (m : Type → Type) [STWorld ω m] (α : Type) := StateRefT' ω σ m α
|
||||
|
||||
@[inline] def StateRefT.run {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT ST m] {α : Type} (x : StateRefT σ m α) (s : σ) : m (α × σ) := do
|
||||
ref ← IO.mkRef s;
|
||||
@[inline] def StateRefT'.run {ω σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST ω) m] {α : Type} (x : StateRefT' ω σ m α) (s : σ) : m (α × σ) := do
|
||||
ref ← ST.mkRef s;
|
||||
a ← x ref;
|
||||
s ← ref.get;
|
||||
pure (a, s)
|
||||
|
||||
@[inline] def StateRefT.run' {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT ST m] {α : Type} (x : StateRefT σ m α) (s : σ) : m α := do
|
||||
@[inline] def StateRefT'.run' {ω σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST ω) m] {α : Type} (x : StateRefT' ω σ m α) (s : σ) : m α := do
|
||||
(a, _) ← x.run s;
|
||||
pure a
|
||||
|
||||
namespace StateRefT
|
||||
variables {σ : Type} {m : Type → Type} {α : Type}
|
||||
namespace StateRefT'
|
||||
variables {ω σ : Type} {m : Type → Type} {α : Type}
|
||||
|
||||
@[inline] protected def lift (x : m α) : StateRefT σ m α :=
|
||||
@[inline] protected def lift (x : m α) : StateRefT' ω σ m α :=
|
||||
fun _ => x
|
||||
|
||||
instance [Monad m] : Monad (StateRefT σ m) := inferInstanceAs (Monad (ReaderT _ _))
|
||||
instance : HasMonadLift m (StateRefT σ m) := ⟨fun _ => StateRefT.lift⟩
|
||||
instance [Monad m] [MonadIO m] : MonadIO (StateRefT σ m) := inferInstanceAs (MonadIO (ReaderT _ _))
|
||||
|
||||
instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT σ m) (StateRefT σ m') :=
|
||||
instance [Monad m] : Monad (StateRefT' ω σ m) := inferInstanceAs (Monad (ReaderT _ _))
|
||||
instance : HasMonadLift m (StateRefT' ω σ m) := ⟨fun _ => StateRefT'.lift⟩
|
||||
instance [Monad m] [MonadIO m] : MonadIO (StateRefT' ω σ m) := inferInstanceAs (MonadIO (ReaderT _ _))
|
||||
instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT' ω σ m) (StateRefT' ω σ m') :=
|
||||
inferInstanceAs (MonadFunctor m m' (ReaderT _ _) (ReaderT _ _))
|
||||
|
||||
@[inline] protected def get [Monad m] [HasMonadLiftT ST m] : StateRefT σ m σ :=
|
||||
@[inline] protected def get [Monad m] [HasMonadLiftT (ST ω) m] : StateRefT' ω σ m σ :=
|
||||
fun ref => ref.get
|
||||
|
||||
@[inline] protected def set [Monad m] [HasMonadLiftT ST m] (s : σ) : StateRefT σ m PUnit :=
|
||||
@[inline] protected def set [Monad m] [HasMonadLiftT (ST ω) m] (s : σ) : StateRefT' ω σ m PUnit :=
|
||||
fun ref => ref.set s
|
||||
|
||||
@[inline] protected def modifyGet [Monad m] [HasMonadLiftT ST m] (f : σ → α × σ) : StateRefT σ m α :=
|
||||
@[inline] protected def modifyGet [Monad m] [HasMonadLiftT (ST ω) m] (f : σ → α × σ) : StateRefT' ω σ m α :=
|
||||
fun ref => ref.modifyGet f
|
||||
|
||||
instance [HasMonadLiftT ST m] [Monad m] : MonadStateOf σ (StateRefT σ m) :=
|
||||
{ get := StateRefT.get,
|
||||
set := StateRefT.set,
|
||||
modifyGet := fun α f => StateRefT.modifyGet f }
|
||||
instance [HasMonadLiftT (ST ω) m] [Monad m] : MonadStateOf σ (StateRefT' ω σ m) :=
|
||||
{ get := StateRefT'.get,
|
||||
set := StateRefT'.set,
|
||||
modifyGet := fun α f => StateRefT'.modifyGet f }
|
||||
|
||||
instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (StateRefT σ m) :=
|
||||
{ throw := fun α => StateRefT.lift ∘ throwThe ε,
|
||||
instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (StateRefT' ω σ m) :=
|
||||
{ throw := fun α => StateRefT'.lift ∘ throwThe ε,
|
||||
catch := fun α x c s => catchThe ε (x s) (fun e => c e s) }
|
||||
|
||||
end StateRefT
|
||||
end StateRefT'
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import Init.Data.String.Basic
|
|||
import Init.Data.ByteArray
|
||||
import Init.System.IOError
|
||||
import Init.System.FilePath
|
||||
import Init.System.ST
|
||||
|
||||
/-- Like https://hackage.haskell.org/package/ghc-Prim-0.5.2.0/docs/GHC-Prim.html#t:RealWorld.
|
||||
Makes sure we never reorder `IO` operations.
|
||||
|
|
@ -26,8 +27,6 @@ def IO.RealWorld : Type := Unit
|
|||
-/
|
||||
def EIO (ε : Type) : Type → Type := EStateM ε IO.RealWorld
|
||||
|
||||
def ST := EIO Empty
|
||||
|
||||
instance monadExceptAdapter {ε ε'} : MonadExceptAdapter ε ε' (EIO ε) (EIO ε') :=
|
||||
inferInstanceAs $ MonadExceptAdapter ε ε' (EStateM ε IO.RealWorld) (EStateM ε' IO.RealWorld)
|
||||
|
||||
|
|
@ -41,7 +40,6 @@ instance (ε : Type) : MonadExceptOf ε (EIO ε) := inferInstanceAs (MonadExcept
|
|||
instance (α ε : Type) : HasOrelse (EIO ε α) := ⟨MonadExcept.orelse⟩
|
||||
instance {ε : Type} {α : Type} [Inhabited ε] : Inhabited (EIO ε α) :=
|
||||
inferInstanceAs (Inhabited (EStateM ε IO.RealWorld α))
|
||||
instance : Monad ST := inferInstanceAs (Monad (EIO Empty))
|
||||
|
||||
abbrev IO : Type → Type := EIO IO.Error
|
||||
|
||||
|
|
@ -323,79 +321,15 @@ def setAccessRights (filename : String) (mode : FileRight) : IO Unit :=
|
|||
Prim.setAccessRights filename mode.flags
|
||||
|
||||
/- References -/
|
||||
constant RefPointed : PointedType.{0} := arbitrary _
|
||||
abbrev Ref (α : Type) := ST.Ref IO.RealWorld α
|
||||
|
||||
structure Ref (α : Type) : Type :=
|
||||
(ref : RefPointed.type) (h : Nonempty α)
|
||||
instance st2eio {ε} : HasMonadLift (ST IO.RealWorld) (EIO ε) :=
|
||||
⟨fun α x s => match x s with
|
||||
| EStateM.Result.ok a s => EStateM.Result.ok a s
|
||||
| EStateM.Result.error ex _ => Empty.rec _ ex⟩
|
||||
|
||||
instance Ref.inhabited {α} [Inhabited α] : Inhabited (Ref α) :=
|
||||
⟨{ ref := RefPointed.val, h := Nonempty.intro $ arbitrary _}⟩
|
||||
|
||||
namespace Prim
|
||||
|
||||
/- Auxiliary definition for showing that `EIO ε α` is inhabited when we have a `Ref α` -/
|
||||
private noncomputable def inhabitedFromRef {α} (r : Ref α) : ST α :=
|
||||
pure $ (Classical.inhabitedOfNonempty r.h).default
|
||||
|
||||
|
||||
@[extern "lean_io_mk_ref"]
|
||||
constant mkRef {α} (a : α) : ST (Ref α) := pure { ref := RefPointed.val, h := Nonempty.intro a }
|
||||
@[extern "lean_io_ref_get"]
|
||||
constant Ref.get {α} (r : @& Ref α) : ST α := inhabitedFromRef r
|
||||
@[extern "lean_io_ref_set"]
|
||||
constant Ref.set {α} (r : @& Ref α) (a : α) : ST Unit := arbitrary _
|
||||
@[extern "lean_io_ref_swap"]
|
||||
constant Ref.swap {α} (r : @& Ref α) (a : α) : ST α := inhabitedFromRef r
|
||||
@[extern "lean_io_ref_take"]
|
||||
unsafe constant Ref.take {α} (r : @& Ref α) : ST α := inhabitedFromRef r
|
||||
@[extern "lean_io_ref_ptr_eq"]
|
||||
constant Ref.ptrEq {α} (r1 r2 : @& Ref α) : ST Bool := arbitrary _
|
||||
@[inline] unsafe def Ref.modifyUnsafe {α : Type} (r : Ref α) (f : α → α) : ST Unit := do
|
||||
v ← Ref.take r;
|
||||
Ref.set r (f v)
|
||||
|
||||
@[inline] unsafe def Ref.modifyGetUnsafe {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : ST β := do
|
||||
v ← Ref.take r;
|
||||
let (b, a) := f v;
|
||||
Ref.set r a;
|
||||
pure b
|
||||
|
||||
@[implementedBy Ref.modifyUnsafe]
|
||||
def Ref.modify {α : Type} (r : Ref α) (f : α → α) : ST Unit := do
|
||||
v ← Ref.get r;
|
||||
Ref.set r (f v)
|
||||
|
||||
@[implementedBy Ref.modifyGetUnsafe]
|
||||
def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : ST β := do
|
||||
v ← Ref.get r;
|
||||
let (b, a) := f v;
|
||||
Ref.set r a;
|
||||
pure b
|
||||
|
||||
end Prim
|
||||
|
||||
section
|
||||
|
||||
@[inline] private def liftST {ε α} (x : ST α) : EIO ε α :=
|
||||
fun s => match x s with
|
||||
| r@(EStateM.Result.error e _) => Empty.rec _ e
|
||||
| EStateM.Result.ok a s => EStateM.Result.ok a s
|
||||
|
||||
instance ST.monadLift {ε} : HasMonadLift ST (EIO ε) :=
|
||||
{ monadLift := fun α => liftST }
|
||||
|
||||
variables {m : Type → Type} [Monad m] [HasMonadLiftT ST m]
|
||||
|
||||
@[inline] def mkRef {α : Type} (a : α) : m (Ref α) := liftM $ Prim.mkRef a
|
||||
@[inline] def Ref.get {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.get r
|
||||
@[inline] def Ref.set {α : Type} (r : Ref α) (a : α) : m Unit := liftM $ Prim.Ref.set r a
|
||||
@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := liftM $ Prim.Ref.swap r a
|
||||
@[inline] unsafe def Ref.take {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.take r
|
||||
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2
|
||||
@[inline] def Ref.modify {α : Type} (r : Ref α) (f : α → α) : m Unit := liftM $ Prim.Ref.modify r f
|
||||
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f
|
||||
|
||||
end
|
||||
def mkRef {α : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST IO.RealWorld) m] (a : α) : m (IO.Ref α) :=
|
||||
ST.mkRef a
|
||||
|
||||
end IO
|
||||
|
||||
|
|
|
|||
107
src/Init/System/ST.lean
Normal file
107
src/Init/System/ST.lean
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Control.EState
|
||||
import Init.Control.Reader
|
||||
|
||||
def EST (ε : Type) (σ : Type) : Type → Type := EStateM ε σ
|
||||
abbrev ST (σ : Type) := EST Empty σ
|
||||
|
||||
instance (ε σ : Type) : Monad (EST ε σ) := inferInstanceAs (Monad (EStateM _ _))
|
||||
instance (ε σ : Type) : MonadExceptOf ε (EST ε σ) := inferInstanceAs (MonadExceptOf ε (EStateM _ _))
|
||||
instance {ε σ : Type} {α : Type} [Inhabited ε] : Inhabited (EST ε σ α) := inferInstanceAs (Inhabited (EStateM _ _ _))
|
||||
instance (σ : Type) : Monad (ST σ) := inferInstanceAs (Monad (EST _ _))
|
||||
|
||||
-- Auxiliary class for inferring the "state" of `EST` and `ST` monads
|
||||
class STWorld (σ : outParam Type) (m : Type → Type)
|
||||
|
||||
instance STWorld.trans {σ m n} [STWorld σ m] [HasMonadLift m n] : STWorld σ n := ⟨⟩
|
||||
instance STWorld.base {ε σ} : STWorld σ (EST ε σ) := ⟨⟩
|
||||
|
||||
def runEST {ε α : Type} (x : forall (σ : Type), EST ε σ α) : Except ε α :=
|
||||
match x Unit () with
|
||||
| EStateM.Result.ok a _ => Except.ok a
|
||||
| EStateM.Result.error ex _ => Except.error ex
|
||||
|
||||
def runST {α : Type} (x : forall (σ : Type), ST σ α) : α :=
|
||||
match x Unit () with
|
||||
| EStateM.Result.ok a _ => a
|
||||
| EStateM.Result.error ex _ => Empty.rec _ ex
|
||||
|
||||
instance st2est {ε σ} : HasMonadLift (ST σ) (EST ε σ) :=
|
||||
⟨fun α x s => match x s with
|
||||
| EStateM.Result.ok a s => EStateM.Result.ok a s
|
||||
| EStateM.Result.error ex _ => Empty.rec _ ex⟩
|
||||
|
||||
namespace ST
|
||||
|
||||
/- References -/
|
||||
constant RefPointed : PointedType.{0} := arbitrary _
|
||||
|
||||
structure Ref (σ : Type) (α : Type) : Type :=
|
||||
(ref : RefPointed.type) (h : Nonempty α)
|
||||
|
||||
instance Ref.inhabited {σ α} [Inhabited α] : Inhabited (Ref σ α) :=
|
||||
⟨{ ref := RefPointed.val, h := Nonempty.intro $ arbitrary _}⟩
|
||||
|
||||
namespace Prim
|
||||
|
||||
/- Auxiliary definition for showing that `ST σ α` is inhabited when we have a `Ref σ α` -/
|
||||
private noncomputable def inhabitedFromRef {σ α} (r : Ref σ α) : ST σ α :=
|
||||
pure $ (Classical.inhabitedOfNonempty r.h).default
|
||||
|
||||
@[extern "lean_io_mk_ref"]
|
||||
constant mkRef {σ α} (a : α) : ST σ (Ref σ α) := pure { ref := RefPointed.val, h := Nonempty.intro a }
|
||||
@[extern "lean_io_ref_get"]
|
||||
constant Ref.get {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r
|
||||
@[extern "lean_io_ref_set"]
|
||||
constant Ref.set {σ α} (r : @& Ref σ α) (a : α) : ST σ Unit := arbitrary _
|
||||
@[extern "lean_io_ref_swap"]
|
||||
constant Ref.swap {σ α} (r : @& Ref σ α) (a : α) : ST σ α := inhabitedFromRef r
|
||||
@[extern "lean_io_ref_take"]
|
||||
unsafe constant Ref.take {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r
|
||||
@[extern "lean_io_ref_ptr_eq"]
|
||||
constant Ref.ptrEq {σ α} (r1 r2 : @& Ref σ α) : ST σ Bool := arbitrary _
|
||||
|
||||
@[inline] unsafe def Ref.modifyUnsafe {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
|
||||
v ← Ref.take r;
|
||||
Ref.set r (f v)
|
||||
|
||||
@[inline] unsafe def Ref.modifyGetUnsafe {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
|
||||
v ← Ref.take r;
|
||||
let (b, a) := f v;
|
||||
Ref.set r a;
|
||||
pure b
|
||||
|
||||
@[implementedBy Ref.modifyUnsafe]
|
||||
def Ref.modify {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
|
||||
v ← Ref.get r;
|
||||
Ref.set r (f v)
|
||||
|
||||
@[implementedBy Ref.modifyGetUnsafe]
|
||||
def Ref.modifyGet {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
|
||||
v ← Ref.get r;
|
||||
let (b, a) := f v;
|
||||
Ref.set r a;
|
||||
pure b
|
||||
|
||||
end Prim
|
||||
|
||||
section
|
||||
variables {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST σ) m]
|
||||
|
||||
@[inline] def mkRef {α : Type} (a : α) : m (Ref σ α) := liftM $ Prim.mkRef a
|
||||
@[inline] def Ref.get {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.get r
|
||||
@[inline] def Ref.set {α : Type} (r : Ref σ α) (a : α) : m Unit := liftM $ Prim.Ref.set r a
|
||||
@[inline] def Ref.swap {α : Type} (r : Ref σ α) (a : α) : m α := liftM $ Prim.Ref.swap r a
|
||||
@[inline] unsafe def Ref.take {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.take r
|
||||
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref σ α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2
|
||||
@[inline] def Ref.modify {α : Type} (r : Ref σ α) (f : α → α) : m Unit := liftM $ Prim.Ref.modify r f
|
||||
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref σ α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f
|
||||
|
||||
end
|
||||
|
||||
end ST
|
||||
|
|
@ -100,7 +100,7 @@ def monadOptsFromLift (m) {n} [MonadOptions m] [HasMonadLiftT m n] : MonadOption
|
|||
{ getOptions := liftM (getOptions : m _) }
|
||||
|
||||
instance ReaderT.monadOpts {ρ m} [MonadOptions m] : MonadOptions (ReaderT ρ m) := monadOptsFromLift m
|
||||
instance StateRefT.monadOpts {σ m} [MonadOptions m] : MonadOptions (StateRefT σ m) := monadOptsFromLift m
|
||||
instance StateRefT.monadOpts {ω σ m} [MonadOptions m] : MonadOptions (StateRefT' ω σ m) := monadOptsFromLift m
|
||||
|
||||
section Methods
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ instance ReaderT.monadError {ρ m} [Monad m] [MonadError m] : MonadError (Reader
|
|||
{ getRef := fun _ => getRef,
|
||||
addContext := fun ref msg _ => addContext ref msg }
|
||||
|
||||
instance StateRefT.monadError {σ m} [Monad m] [MonadError m] : MonadError (StateRefT σ m) :=
|
||||
instance StateRefT.monadError {ω σ m} [Monad m] [MonadError m] : MonadError (StateRefT' ω σ m) :=
|
||||
inferInstanceAs (MonadError (ReaderT _ _))
|
||||
|
||||
section Methods
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def monadEnvFromLift (m) {n} [MonadEnv m] [HasMonadLiftT m n] : MonadEnv n :=
|
|||
modifyEnv := fun f => liftM (modifyEnv f : m Unit) }
|
||||
|
||||
instance ReaderT.monadEnv {m ρ} [Monad m] [MonadEnv m] : MonadEnv (ReaderT ρ m) := monadEnvFromLift m
|
||||
instance StateRefT.monadEnv {m σ} [MonadEnv m] : MonadEnv (StateRefT σ m) := monadEnvFromLift m
|
||||
instance StateRefT.monadEnv {ω m σ} [MonadEnv m] : MonadEnv (StateRefT' ω σ m) := monadEnvFromLift m
|
||||
instance OptionT.monadEnv {m} [Monad m] [MonadEnv m] : MonadEnv (OptionT m) := monadEnvFromLift m
|
||||
|
||||
section Methods
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ instance ReaderT.monadTracer (ρ : Type) (m : Type → Type) [MonadTracer m] : M
|
|||
trace := fun n x _ => MonadTracer.trace n x,
|
||||
traceM := fun n x ctx => MonadTracer.traceM n (x ctx) }
|
||||
|
||||
instance StateRefT.monadTracer (σ : Type) (m : Type → Type) [MonadTracer m] : MonadTracer (StateRefT σ m) :=
|
||||
instance StateRefT.monadTracer (ω σ : Type) (m : Type → Type) [MonadTracer m] : MonadTracer (StateRefT' ω σ m) :=
|
||||
inferInstanceAs (MonadTracer (ReaderT _ _))
|
||||
|
||||
class MonadTracerAdapter (m : Type → Type) :=
|
||||
|
|
|
|||
|
|
@ -52,3 +52,15 @@ IO.println $ "state1 " ++ toString a2;
|
|||
pure (a0 + a1 + a2)
|
||||
|
||||
#eval ((f4.run' ⟨10⟩).run' ⟨20⟩).run' ⟨30⟩
|
||||
|
||||
abbrev S (ω : Type) := StateRefT Nat $ StateRefT String $ ST ω
|
||||
|
||||
def f5 {ω} : S ω Unit := do
|
||||
s ← getThe String;
|
||||
modify fun n => n + s.length;
|
||||
pure ()
|
||||
|
||||
def f5Pure (n : Nat) (s : String) :=
|
||||
runST (fun _ => (f5.run n).run s)
|
||||
|
||||
#eval f5Pure 10 "hello world"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue