feat: real ST monad

@Kha: the new `ST` (and `EST`) are escapable like the Haskell ST monad.
It makes `StateRefT` much more useful because we can now run it from pure
code.
This commit is contained in:
Leonardo de Moura 2020-08-23 12:08:13 -07:00
parent a8f68f6360
commit 77b9445544
8 changed files with 153 additions and 100 deletions

View file

@ -9,47 +9,47 @@ prelude
import Init.System.IO
import Init.Control.State
def StateRefT (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (IO.Ref σ) m α
def StateRefT' (ω : Type) (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (ST.Ref ω σ) m α
abbrev StateRefT {ω : Type} (σ : Type) (m : Type → Type) [STWorld ω m] (α : Type) := StateRefT' ω σ m α
@[inline] def StateRefT.run {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT ST m] {α : Type} (x : StateRefT σ m α) (s : σ) : m (α × σ) := do
ref ← IO.mkRef s;
@[inline] def StateRefT'.run {ω σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST ω) m] {α : Type} (x : StateRefT' ω σ m α) (s : σ) : m (α × σ) := do
ref ← ST.mkRef s;
a ← x ref;
s ← ref.get;
pure (a, s)
@[inline] def StateRefT.run' {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT ST m] {α : Type} (x : StateRefT σ m α) (s : σ) : m α := do
@[inline] def StateRefT'.run' {ω σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST ω) m] {α : Type} (x : StateRefT' ω σ m α) (s : σ) : m α := do
(a, _) ← x.run s;
pure a
namespace StateRefT
variables {σ : Type} {m : Type → Type} {α : Type}
namespace StateRefT'
variables {ω σ : Type} {m : Type → Type} {α : Type}
@[inline] protected def lift (x : m α) : StateRefT σ m α :=
@[inline] protected def lift (x : m α) : StateRefT' ω σ m α :=
fun _ => x
instance [Monad m] : Monad (StateRefT σ m) := inferInstanceAs (Monad (ReaderT _ _))
instance : HasMonadLift m (StateRefT σ m) := ⟨fun _ => StateRefT.lift⟩
instance [Monad m] [MonadIO m] : MonadIO (StateRefT σ m) := inferInstanceAs (MonadIO (ReaderT _ _))
instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT σ m) (StateRefT σ m') :=
instance [Monad m] : Monad (StateRefT' ω σ m) := inferInstanceAs (Monad (ReaderT _ _))
instance : HasMonadLift m (StateRefT' ω σ m) := ⟨fun _ => StateRefT'.lift⟩
instance [Monad m] [MonadIO m] : MonadIO (StateRefT' ω σ m) := inferInstanceAs (MonadIO (ReaderT _ _))
instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT' ω σ m) (StateRefT' ω σ m') :=
inferInstanceAs (MonadFunctor m m' (ReaderT _ _) (ReaderT _ _))
@[inline] protected def get [Monad m] [HasMonadLiftT ST m] : StateRefT σ m σ :=
@[inline] protected def get [Monad m] [HasMonadLiftT (ST ω) m] : StateRefT' ω σ m σ :=
fun ref => ref.get
@[inline] protected def set [Monad m] [HasMonadLiftT ST m] (s : σ) : StateRefT σ m PUnit :=
@[inline] protected def set [Monad m] [HasMonadLiftT (ST ω) m] (s : σ) : StateRefT' ω σ m PUnit :=
fun ref => ref.set s
@[inline] protected def modifyGet [Monad m] [HasMonadLiftT ST m] (f : σα × σ) : StateRefT σ m α :=
@[inline] protected def modifyGet [Monad m] [HasMonadLiftT (ST ω) m] (f : σα × σ) : StateRefT' ω σ m α :=
fun ref => ref.modifyGet f
instance [HasMonadLiftT ST m] [Monad m] : MonadStateOf σ (StateRefT σ m) :=
{ get := StateRefT.get,
set := StateRefT.set,
modifyGet := fun α f => StateRefT.modifyGet f }
instance [HasMonadLiftT (ST ω) m] [Monad m] : MonadStateOf σ (StateRefT' ω σ m) :=
{ get := StateRefT'.get,
set := StateRefT'.set,
modifyGet := fun α f => StateRefT'.modifyGet f }
instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (StateRefT σ m) :=
{ throw := fun α => StateRefT.lift ∘ throwThe ε,
instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (StateRefT' ω σ m) :=
{ throw := fun α => StateRefT'.lift ∘ throwThe ε,
catch := fun α x c s => catchThe ε (x s) (fun e => c e s) }
end StateRefT
end StateRefT'

View file

@ -10,6 +10,7 @@ import Init.Data.String.Basic
import Init.Data.ByteArray
import Init.System.IOError
import Init.System.FilePath
import Init.System.ST
/-- Like https://hackage.haskell.org/package/ghc-Prim-0.5.2.0/docs/GHC-Prim.html#t:RealWorld.
Makes sure we never reorder `IO` operations.
@ -26,8 +27,6 @@ def IO.RealWorld : Type := Unit
-/
def EIO (ε : Type) : Type → Type := EStateM ε IO.RealWorld
def ST := EIO Empty
instance monadExceptAdapter {ε ε'} : MonadExceptAdapter ε ε' (EIO ε) (EIO ε') :=
inferInstanceAs $ MonadExceptAdapter ε ε' (EStateM ε IO.RealWorld) (EStateM ε' IO.RealWorld)
@ -41,7 +40,6 @@ instance (ε : Type) : MonadExceptOf ε (EIO ε) := inferInstanceAs (MonadExcept
instance (α ε : Type) : HasOrelse (EIO ε α) := ⟨MonadExcept.orelse⟩
instance {ε : Type} {α : Type} [Inhabited ε] : Inhabited (EIO ε α) :=
inferInstanceAs (Inhabited (EStateM ε IO.RealWorld α))
instance : Monad ST := inferInstanceAs (Monad (EIO Empty))
abbrev IO : Type → Type := EIO IO.Error
@ -323,79 +321,15 @@ def setAccessRights (filename : String) (mode : FileRight) : IO Unit :=
Prim.setAccessRights filename mode.flags
/- References -/
constant RefPointed : PointedType.{0} := arbitrary _
abbrev Ref (α : Type) := ST.Ref IO.RealWorld α
structure Ref (α : Type) : Type :=
(ref : RefPointed.type) (h : Nonempty α)
instance st2eio {ε} : HasMonadLift (ST IO.RealWorld) (EIO ε) :=
⟨fun α x s => match x s with
| EStateM.Result.ok a s => EStateM.Result.ok a s
| EStateM.Result.error ex _ => Empty.rec _ ex⟩
instance Ref.inhabited {α} [Inhabited α] : Inhabited (Ref α) :=
⟨{ ref := RefPointed.val, h := Nonempty.intro $ arbitrary _}⟩
namespace Prim
/- Auxiliary definition for showing that `EIO ε α` is inhabited when we have a `Ref α` -/
private noncomputable def inhabitedFromRef {α} (r : Ref α) : ST α :=
pure $ (Classical.inhabitedOfNonempty r.h).default
@[extern "lean_io_mk_ref"]
constant mkRef {α} (a : α) : ST (Ref α) := pure { ref := RefPointed.val, h := Nonempty.intro a }
@[extern "lean_io_ref_get"]
constant Ref.get {α} (r : @& Ref α) : ST α := inhabitedFromRef r
@[extern "lean_io_ref_set"]
constant Ref.set {α} (r : @& Ref α) (a : α) : ST Unit := arbitrary _
@[extern "lean_io_ref_swap"]
constant Ref.swap {α} (r : @& Ref α) (a : α) : ST α := inhabitedFromRef r
@[extern "lean_io_ref_take"]
unsafe constant Ref.take {α} (r : @& Ref α) : ST α := inhabitedFromRef r
@[extern "lean_io_ref_ptr_eq"]
constant Ref.ptrEq {α} (r1 r2 : @& Ref α) : ST Bool := arbitrary _
@[inline] unsafe def Ref.modifyUnsafe {α : Type} (r : Ref α) (f : αα) : ST Unit := do
v ← Ref.take r;
Ref.set r (f v)
@[inline] unsafe def Ref.modifyGetUnsafe {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : ST β := do
v ← Ref.take r;
let (b, a) := f v;
Ref.set r a;
pure b
@[implementedBy Ref.modifyUnsafe]
def Ref.modify {α : Type} (r : Ref α) (f : αα) : ST Unit := do
v ← Ref.get r;
Ref.set r (f v)
@[implementedBy Ref.modifyGetUnsafe]
def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : ST β := do
v ← Ref.get r;
let (b, a) := f v;
Ref.set r a;
pure b
end Prim
section
@[inline] private def liftST {ε α} (x : ST α) : EIO ε α :=
fun s => match x s with
| r@(EStateM.Result.error e _) => Empty.rec _ e
| EStateM.Result.ok a s => EStateM.Result.ok a s
instance ST.monadLift {ε} : HasMonadLift ST (EIO ε) :=
{ monadLift := fun α => liftST }
variables {m : Type → Type} [Monad m] [HasMonadLiftT ST m]
@[inline] def mkRef {α : Type} (a : α) : m (Ref α) := liftM $ Prim.mkRef a
@[inline] def Ref.get {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.get r
@[inline] def Ref.set {α : Type} (r : Ref α) (a : α) : m Unit := liftM $ Prim.Ref.set r a
@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := liftM $ Prim.Ref.swap r a
@[inline] unsafe def Ref.take {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.take r
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2
@[inline] def Ref.modify {α : Type} (r : Ref α) (f : αα) : m Unit := liftM $ Prim.Ref.modify r f
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f
end
def mkRef {α : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST IO.RealWorld) m] (a : α) : m (IO.Ref α) :=
ST.mkRef a
end IO

107
src/Init/System/ST.lean Normal file
View file

@ -0,0 +1,107 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Control.EState
import Init.Control.Reader
def EST (ε : Type) (σ : Type) : Type → Type := EStateM ε σ
abbrev ST (σ : Type) := EST Empty σ
instance (ε σ : Type) : Monad (EST ε σ) := inferInstanceAs (Monad (EStateM _ _))
instance (ε σ : Type) : MonadExceptOf ε (EST ε σ) := inferInstanceAs (MonadExceptOf ε (EStateM _ _))
instance {ε σ : Type} {α : Type} [Inhabited ε] : Inhabited (EST ε σ α) := inferInstanceAs (Inhabited (EStateM _ _ _))
instance (σ : Type) : Monad (ST σ) := inferInstanceAs (Monad (EST _ _))
-- Auxiliary class for inferring the "state" of `EST` and `ST` monads
class STWorld (σ : outParam Type) (m : Type → Type)
instance STWorld.trans {σ m n} [STWorld σ m] [HasMonadLift m n] : STWorld σ n := ⟨⟩
instance STWorld.base {ε σ} : STWorld σ (EST ε σ) := ⟨⟩
def runEST {ε α : Type} (x : forall (σ : Type), EST ε σ α) : Except ε α :=
match x Unit () with
| EStateM.Result.ok a _ => Except.ok a
| EStateM.Result.error ex _ => Except.error ex
def runST {α : Type} (x : forall (σ : Type), ST σ α) : α :=
match x Unit () with
| EStateM.Result.ok a _ => a
| EStateM.Result.error ex _ => Empty.rec _ ex
instance st2est {ε σ} : HasMonadLift (ST σ) (EST ε σ) :=
⟨fun α x s => match x s with
| EStateM.Result.ok a s => EStateM.Result.ok a s
| EStateM.Result.error ex _ => Empty.rec _ ex⟩
namespace ST
/- References -/
constant RefPointed : PointedType.{0} := arbitrary _
structure Ref (σ : Type) (α : Type) : Type :=
(ref : RefPointed.type) (h : Nonempty α)
instance Ref.inhabited {σ α} [Inhabited α] : Inhabited (Ref σ α) :=
⟨{ ref := RefPointed.val, h := Nonempty.intro $ arbitrary _}⟩
namespace Prim
/- Auxiliary definition for showing that `ST σ α` is inhabited when we have a `Ref σ α` -/
private noncomputable def inhabitedFromRef {σ α} (r : Ref σ α) : ST σ α :=
pure $ (Classical.inhabitedOfNonempty r.h).default
@[extern "lean_io_mk_ref"]
constant mkRef {σ α} (a : α) : ST σ (Ref σ α) := pure { ref := RefPointed.val, h := Nonempty.intro a }
@[extern "lean_io_ref_get"]
constant Ref.get {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r
@[extern "lean_io_ref_set"]
constant Ref.set {σ α} (r : @& Ref σ α) (a : α) : ST σ Unit := arbitrary _
@[extern "lean_io_ref_swap"]
constant Ref.swap {σ α} (r : @& Ref σ α) (a : α) : ST σ α := inhabitedFromRef r
@[extern "lean_io_ref_take"]
unsafe constant Ref.take {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r
@[extern "lean_io_ref_ptr_eq"]
constant Ref.ptrEq {σ α} (r1 r2 : @& Ref σ α) : ST σ Bool := arbitrary _
@[inline] unsafe def Ref.modifyUnsafe {σ α : Type} (r : Ref σ α) (f : αα) : ST σ Unit := do
v ← Ref.take r;
Ref.set r (f v)
@[inline] unsafe def Ref.modifyGetUnsafe {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
v ← Ref.take r;
let (b, a) := f v;
Ref.set r a;
pure b
@[implementedBy Ref.modifyUnsafe]
def Ref.modify {σ α : Type} (r : Ref σ α) (f : αα) : ST σ Unit := do
v ← Ref.get r;
Ref.set r (f v)
@[implementedBy Ref.modifyGetUnsafe]
def Ref.modifyGet {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
v ← Ref.get r;
let (b, a) := f v;
Ref.set r a;
pure b
end Prim
section
variables {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST σ) m]
@[inline] def mkRef {α : Type} (a : α) : m (Ref σ α) := liftM $ Prim.mkRef a
@[inline] def Ref.get {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.get r
@[inline] def Ref.set {α : Type} (r : Ref σ α) (a : α) : m Unit := liftM $ Prim.Ref.set r a
@[inline] def Ref.swap {α : Type} (r : Ref σ α) (a : α) : m α := liftM $ Prim.Ref.swap r a
@[inline] unsafe def Ref.take {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.take r
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref σ α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2
@[inline] def Ref.modify {α : Type} (r : Ref σ α) (f : αα) : m Unit := liftM $ Prim.Ref.modify r f
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref σ α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f
end
end ST

View file

@ -100,7 +100,7 @@ def monadOptsFromLift (m) {n} [MonadOptions m] [HasMonadLiftT m n] : MonadOption
{ getOptions := liftM (getOptions : m _) }
instance ReaderT.monadOpts {ρ m} [MonadOptions m] : MonadOptions (ReaderT ρ m) := monadOptsFromLift m
instance StateRefT.monadOpts {σ m} [MonadOptions m] : MonadOptions (StateRefT σ m) := monadOptsFromLift m
instance StateRefT.monadOpts {ω σ m} [MonadOptions m] : MonadOptions (StateRefT' ω σ m) := monadOptsFromLift m
section Methods

View file

@ -34,7 +34,7 @@ instance ReaderT.monadError {ρ m} [Monad m] [MonadError m] : MonadError (Reader
{ getRef := fun _ => getRef,
addContext := fun ref msg _ => addContext ref msg }
instance StateRefT.monadError {σ m} [Monad m] [MonadError m] : MonadError (StateRefT σ m) :=
instance StateRefT.monadError {ω σ m} [Monad m] [MonadError m] : MonadError (StateRefT' ω σ m) :=
inferInstanceAs (MonadError (ReaderT _ _))
section Methods

View file

@ -21,7 +21,7 @@ def monadEnvFromLift (m) {n} [MonadEnv m] [HasMonadLiftT m n] : MonadEnv n :=
modifyEnv := fun f => liftM (modifyEnv f : m Unit) }
instance ReaderT.monadEnv {m ρ} [Monad m] [MonadEnv m] : MonadEnv (ReaderT ρ m) := monadEnvFromLift m
instance StateRefT.monadEnv {m σ} [MonadEnv m] : MonadEnv (StateRefT σ m) := monadEnvFromLift m
instance StateRefT.monadEnv {ω m σ} [MonadEnv m] : MonadEnv (StateRefT' ω σ m) := monadEnvFromLift m
instance OptionT.monadEnv {m} [Monad m] [MonadEnv m] : MonadEnv (OptionT m) := monadEnvFromLift m
section Methods

View file

@ -21,7 +21,7 @@ instance ReaderT.monadTracer (ρ : Type) (m : Type → Type) [MonadTracer m] : M
trace := fun n x _ => MonadTracer.trace n x,
traceM := fun n x ctx => MonadTracer.traceM n (x ctx) }
instance StateRefT.monadTracer (σ : Type) (m : Type → Type) [MonadTracer m] : MonadTracer (StateRefT σ m) :=
instance StateRefT.monadTracer (ω σ : Type) (m : Type → Type) [MonadTracer m] : MonadTracer (StateRefT' ω σ m) :=
inferInstanceAs (MonadTracer (ReaderT _ _))
class MonadTracerAdapter (m : Type → Type) :=

View file

@ -52,3 +52,15 @@ IO.println $ "state1 " ++ toString a2;
pure (a0 + a1 + a2)
#eval ((f4.run' ⟨10⟩).run' ⟨20⟩).run' ⟨30⟩
abbrev S (ω : Type) := StateRefT Nat $ StateRefT String $ ST ω
def f5 {ω} : S ω Unit := do
s ← getThe String;
modify fun n => n + s.length;
pure ()
def f5Pure (n : Nat) (s : String) :=
runST (fun _ => (f5.run n).run s)
#eval f5Pure 10 "hello world"