From 77b9445544e0ee17f4f954f751a3e6a54acbb2dd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 23 Aug 2020 12:08:13 -0700 Subject: [PATCH] feat: real ST monad @Kha: the new `ST` (and `EST`) are escapable like the Haskell ST monad. It makes `StateRefT` much more useful because we can now run it from pure code. --- src/Init/Control/StateRef.lean | 44 +++++++------- src/Init/System/IO.lean | 82 +++---------------------- src/Init/System/ST.lean | 107 +++++++++++++++++++++++++++++++++ src/Lean/Data/Options.lean | 2 +- src/Lean/Exception.lean | 2 +- src/Lean/MonadEnv.lean | 2 +- src/Lean/Util/Trace.lean | 2 +- tests/lean/run/stateRef.lean | 12 ++++ 8 files changed, 153 insertions(+), 100 deletions(-) create mode 100644 src/Init/System/ST.lean diff --git a/src/Init/Control/StateRef.lean b/src/Init/Control/StateRef.lean index 961b760502..077a713e57 100644 --- a/src/Init/Control/StateRef.lean +++ b/src/Init/Control/StateRef.lean @@ -9,47 +9,47 @@ prelude import Init.System.IO import Init.Control.State -def StateRefT (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (IO.Ref σ) m α +def StateRefT' (ω : Type) (σ : Type) (m : Type → Type) (α : Type) : Type := ReaderT (ST.Ref ω σ) m α +abbrev StateRefT {ω : Type} (σ : Type) (m : Type → Type) [STWorld ω m] (α : Type) := StateRefT' ω σ m α -@[inline] def StateRefT.run {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT ST m] {α : Type} (x : StateRefT σ m α) (s : σ) : m (α × σ) := do -ref ← IO.mkRef s; +@[inline] def StateRefT'.run {ω σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST ω) m] {α : Type} (x : StateRefT' ω σ m α) (s : σ) : m (α × σ) := do +ref ← ST.mkRef s; a ← x ref; s ← ref.get; pure (a, s) -@[inline] def StateRefT.run' {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT ST m] {α : Type} (x : StateRefT σ m α) (s : σ) : m α := do +@[inline] def StateRefT'.run' {ω σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST ω) m] {α : Type} (x : StateRefT' ω σ m α) (s : σ) : m α := do (a, _) ← x.run s; pure a -namespace StateRefT -variables {σ : Type} {m : Type → Type} {α : Type} +namespace StateRefT' +variables {ω σ : Type} {m : Type → Type} {α : Type} -@[inline] protected def lift (x : m α) : StateRefT σ m α := +@[inline] protected def lift (x : m α) : StateRefT' ω σ m α := fun _ => x -instance [Monad m] : Monad (StateRefT σ m) := inferInstanceAs (Monad (ReaderT _ _)) -instance : HasMonadLift m (StateRefT σ m) := ⟨fun _ => StateRefT.lift⟩ -instance [Monad m] [MonadIO m] : MonadIO (StateRefT σ m) := inferInstanceAs (MonadIO (ReaderT _ _)) - -instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT σ m) (StateRefT σ m') := +instance [Monad m] : Monad (StateRefT' ω σ m) := inferInstanceAs (Monad (ReaderT _ _)) +instance : HasMonadLift m (StateRefT' ω σ m) := ⟨fun _ => StateRefT'.lift⟩ +instance [Monad m] [MonadIO m] : MonadIO (StateRefT' ω σ m) := inferInstanceAs (MonadIO (ReaderT _ _)) +instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateRefT' ω σ m) (StateRefT' ω σ m') := inferInstanceAs (MonadFunctor m m' (ReaderT _ _) (ReaderT _ _)) -@[inline] protected def get [Monad m] [HasMonadLiftT ST m] : StateRefT σ m σ := +@[inline] protected def get [Monad m] [HasMonadLiftT (ST ω) m] : StateRefT' ω σ m σ := fun ref => ref.get -@[inline] protected def set [Monad m] [HasMonadLiftT ST m] (s : σ) : StateRefT σ m PUnit := +@[inline] protected def set [Monad m] [HasMonadLiftT (ST ω) m] (s : σ) : StateRefT' ω σ m PUnit := fun ref => ref.set s -@[inline] protected def modifyGet [Monad m] [HasMonadLiftT ST m] (f : σ → α × σ) : StateRefT σ m α := +@[inline] protected def modifyGet [Monad m] [HasMonadLiftT (ST ω) m] (f : σ → α × σ) : StateRefT' ω σ m α := fun ref => ref.modifyGet f -instance [HasMonadLiftT ST m] [Monad m] : MonadStateOf σ (StateRefT σ m) := -{ get := StateRefT.get, - set := StateRefT.set, - modifyGet := fun α f => StateRefT.modifyGet f } +instance [HasMonadLiftT (ST ω) m] [Monad m] : MonadStateOf σ (StateRefT' ω σ m) := +{ get := StateRefT'.get, + set := StateRefT'.set, + modifyGet := fun α f => StateRefT'.modifyGet f } -instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (StateRefT σ m) := -{ throw := fun α => StateRefT.lift ∘ throwThe ε, +instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (StateRefT' ω σ m) := +{ throw := fun α => StateRefT'.lift ∘ throwThe ε, catch := fun α x c s => catchThe ε (x s) (fun e => c e s) } -end StateRefT +end StateRefT' diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index a7cc9e1829..a0a4a63273 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -10,6 +10,7 @@ import Init.Data.String.Basic import Init.Data.ByteArray import Init.System.IOError import Init.System.FilePath +import Init.System.ST /-- Like https://hackage.haskell.org/package/ghc-Prim-0.5.2.0/docs/GHC-Prim.html#t:RealWorld. Makes sure we never reorder `IO` operations. @@ -26,8 +27,6 @@ def IO.RealWorld : Type := Unit -/ def EIO (ε : Type) : Type → Type := EStateM ε IO.RealWorld -def ST := EIO Empty - instance monadExceptAdapter {ε ε'} : MonadExceptAdapter ε ε' (EIO ε) (EIO ε') := inferInstanceAs $ MonadExceptAdapter ε ε' (EStateM ε IO.RealWorld) (EStateM ε' IO.RealWorld) @@ -41,7 +40,6 @@ instance (ε : Type) : MonadExceptOf ε (EIO ε) := inferInstanceAs (MonadExcept instance (α ε : Type) : HasOrelse (EIO ε α) := ⟨MonadExcept.orelse⟩ instance {ε : Type} {α : Type} [Inhabited ε] : Inhabited (EIO ε α) := inferInstanceAs (Inhabited (EStateM ε IO.RealWorld α)) -instance : Monad ST := inferInstanceAs (Monad (EIO Empty)) abbrev IO : Type → Type := EIO IO.Error @@ -323,79 +321,15 @@ def setAccessRights (filename : String) (mode : FileRight) : IO Unit := Prim.setAccessRights filename mode.flags /- References -/ -constant RefPointed : PointedType.{0} := arbitrary _ +abbrev Ref (α : Type) := ST.Ref IO.RealWorld α -structure Ref (α : Type) : Type := -(ref : RefPointed.type) (h : Nonempty α) +instance st2eio {ε} : HasMonadLift (ST IO.RealWorld) (EIO ε) := +⟨fun α x s => match x s with + | EStateM.Result.ok a s => EStateM.Result.ok a s + | EStateM.Result.error ex _ => Empty.rec _ ex⟩ -instance Ref.inhabited {α} [Inhabited α] : Inhabited (Ref α) := -⟨{ ref := RefPointed.val, h := Nonempty.intro $ arbitrary _}⟩ - -namespace Prim - -/- Auxiliary definition for showing that `EIO ε α` is inhabited when we have a `Ref α` -/ -private noncomputable def inhabitedFromRef {α} (r : Ref α) : ST α := -pure $ (Classical.inhabitedOfNonempty r.h).default - - -@[extern "lean_io_mk_ref"] -constant mkRef {α} (a : α) : ST (Ref α) := pure { ref := RefPointed.val, h := Nonempty.intro a } -@[extern "lean_io_ref_get"] -constant Ref.get {α} (r : @& Ref α) : ST α := inhabitedFromRef r -@[extern "lean_io_ref_set"] -constant Ref.set {α} (r : @& Ref α) (a : α) : ST Unit := arbitrary _ -@[extern "lean_io_ref_swap"] -constant Ref.swap {α} (r : @& Ref α) (a : α) : ST α := inhabitedFromRef r -@[extern "lean_io_ref_take"] -unsafe constant Ref.take {α} (r : @& Ref α) : ST α := inhabitedFromRef r -@[extern "lean_io_ref_ptr_eq"] -constant Ref.ptrEq {α} (r1 r2 : @& Ref α) : ST Bool := arbitrary _ -@[inline] unsafe def Ref.modifyUnsafe {α : Type} (r : Ref α) (f : α → α) : ST Unit := do -v ← Ref.take r; -Ref.set r (f v) - -@[inline] unsafe def Ref.modifyGetUnsafe {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : ST β := do -v ← Ref.take r; -let (b, a) := f v; -Ref.set r a; -pure b - -@[implementedBy Ref.modifyUnsafe] -def Ref.modify {α : Type} (r : Ref α) (f : α → α) : ST Unit := do -v ← Ref.get r; -Ref.set r (f v) - -@[implementedBy Ref.modifyGetUnsafe] -def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : ST β := do -v ← Ref.get r; -let (b, a) := f v; -Ref.set r a; -pure b - -end Prim - -section - -@[inline] private def liftST {ε α} (x : ST α) : EIO ε α := -fun s => match x s with - | r@(EStateM.Result.error e _) => Empty.rec _ e - | EStateM.Result.ok a s => EStateM.Result.ok a s - -instance ST.monadLift {ε} : HasMonadLift ST (EIO ε) := -{ monadLift := fun α => liftST } - -variables {m : Type → Type} [Monad m] [HasMonadLiftT ST m] - -@[inline] def mkRef {α : Type} (a : α) : m (Ref α) := liftM $ Prim.mkRef a -@[inline] def Ref.get {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.get r -@[inline] def Ref.set {α : Type} (r : Ref α) (a : α) : m Unit := liftM $ Prim.Ref.set r a -@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := liftM $ Prim.Ref.swap r a -@[inline] unsafe def Ref.take {α : Type} (r : Ref α) : m α := liftM $ Prim.Ref.take r -@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2 -@[inline] def Ref.modify {α : Type} (r : Ref α) (f : α → α) : m Unit := liftM $ Prim.Ref.modify r f -@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f - -end +def mkRef {α : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST IO.RealWorld) m] (a : α) : m (IO.Ref α) := +ST.mkRef a end IO diff --git a/src/Init/System/ST.lean b/src/Init/System/ST.lean new file mode 100644 index 0000000000..6c731c8296 --- /dev/null +++ b/src/Init/System/ST.lean @@ -0,0 +1,107 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Control.EState +import Init.Control.Reader + +def EST (ε : Type) (σ : Type) : Type → Type := EStateM ε σ +abbrev ST (σ : Type) := EST Empty σ + +instance (ε σ : Type) : Monad (EST ε σ) := inferInstanceAs (Monad (EStateM _ _)) +instance (ε σ : Type) : MonadExceptOf ε (EST ε σ) := inferInstanceAs (MonadExceptOf ε (EStateM _ _)) +instance {ε σ : Type} {α : Type} [Inhabited ε] : Inhabited (EST ε σ α) := inferInstanceAs (Inhabited (EStateM _ _ _)) +instance (σ : Type) : Monad (ST σ) := inferInstanceAs (Monad (EST _ _)) + +-- Auxiliary class for inferring the "state" of `EST` and `ST` monads +class STWorld (σ : outParam Type) (m : Type → Type) + +instance STWorld.trans {σ m n} [STWorld σ m] [HasMonadLift m n] : STWorld σ n := ⟨⟩ +instance STWorld.base {ε σ} : STWorld σ (EST ε σ) := ⟨⟩ + +def runEST {ε α : Type} (x : forall (σ : Type), EST ε σ α) : Except ε α := +match x Unit () with +| EStateM.Result.ok a _ => Except.ok a +| EStateM.Result.error ex _ => Except.error ex + +def runST {α : Type} (x : forall (σ : Type), ST σ α) : α := +match x Unit () with +| EStateM.Result.ok a _ => a +| EStateM.Result.error ex _ => Empty.rec _ ex + +instance st2est {ε σ} : HasMonadLift (ST σ) (EST ε σ) := +⟨fun α x s => match x s with + | EStateM.Result.ok a s => EStateM.Result.ok a s + | EStateM.Result.error ex _ => Empty.rec _ ex⟩ + +namespace ST + +/- References -/ +constant RefPointed : PointedType.{0} := arbitrary _ + +structure Ref (σ : Type) (α : Type) : Type := +(ref : RefPointed.type) (h : Nonempty α) + +instance Ref.inhabited {σ α} [Inhabited α] : Inhabited (Ref σ α) := +⟨{ ref := RefPointed.val, h := Nonempty.intro $ arbitrary _}⟩ + +namespace Prim + +/- Auxiliary definition for showing that `ST σ α` is inhabited when we have a `Ref σ α` -/ +private noncomputable def inhabitedFromRef {σ α} (r : Ref σ α) : ST σ α := +pure $ (Classical.inhabitedOfNonempty r.h).default + +@[extern "lean_io_mk_ref"] +constant mkRef {σ α} (a : α) : ST σ (Ref σ α) := pure { ref := RefPointed.val, h := Nonempty.intro a } +@[extern "lean_io_ref_get"] +constant Ref.get {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r +@[extern "lean_io_ref_set"] +constant Ref.set {σ α} (r : @& Ref σ α) (a : α) : ST σ Unit := arbitrary _ +@[extern "lean_io_ref_swap"] +constant Ref.swap {σ α} (r : @& Ref σ α) (a : α) : ST σ α := inhabitedFromRef r +@[extern "lean_io_ref_take"] +unsafe constant Ref.take {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r +@[extern "lean_io_ref_ptr_eq"] +constant Ref.ptrEq {σ α} (r1 r2 : @& Ref σ α) : ST σ Bool := arbitrary _ + +@[inline] unsafe def Ref.modifyUnsafe {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do +v ← Ref.take r; +Ref.set r (f v) + +@[inline] unsafe def Ref.modifyGetUnsafe {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do +v ← Ref.take r; +let (b, a) := f v; +Ref.set r a; +pure b + +@[implementedBy Ref.modifyUnsafe] +def Ref.modify {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do +v ← Ref.get r; +Ref.set r (f v) + +@[implementedBy Ref.modifyGetUnsafe] +def Ref.modifyGet {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do +v ← Ref.get r; +let (b, a) := f v; +Ref.set r a; +pure b + +end Prim + +section +variables {σ : Type} {m : Type → Type} [Monad m] [HasMonadLiftT (ST σ) m] + +@[inline] def mkRef {α : Type} (a : α) : m (Ref σ α) := liftM $ Prim.mkRef a +@[inline] def Ref.get {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.get r +@[inline] def Ref.set {α : Type} (r : Ref σ α) (a : α) : m Unit := liftM $ Prim.Ref.set r a +@[inline] def Ref.swap {α : Type} (r : Ref σ α) (a : α) : m α := liftM $ Prim.Ref.swap r a +@[inline] unsafe def Ref.take {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.take r +@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref σ α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2 +@[inline] def Ref.modify {α : Type} (r : Ref σ α) (f : α → α) : m Unit := liftM $ Prim.Ref.modify r f +@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref σ α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f + +end + +end ST diff --git a/src/Lean/Data/Options.lean b/src/Lean/Data/Options.lean index 2d6f8a43cc..6feb9f4ab7 100644 --- a/src/Lean/Data/Options.lean +++ b/src/Lean/Data/Options.lean @@ -100,7 +100,7 @@ def monadOptsFromLift (m) {n} [MonadOptions m] [HasMonadLiftT m n] : MonadOption { getOptions := liftM (getOptions : m _) } instance ReaderT.monadOpts {ρ m} [MonadOptions m] : MonadOptions (ReaderT ρ m) := monadOptsFromLift m -instance StateRefT.monadOpts {σ m} [MonadOptions m] : MonadOptions (StateRefT σ m) := monadOptsFromLift m +instance StateRefT.monadOpts {ω σ m} [MonadOptions m] : MonadOptions (StateRefT' ω σ m) := monadOptsFromLift m section Methods diff --git a/src/Lean/Exception.lean b/src/Lean/Exception.lean index 0d63459cce..f2489e80b2 100644 --- a/src/Lean/Exception.lean +++ b/src/Lean/Exception.lean @@ -34,7 +34,7 @@ instance ReaderT.monadError {ρ m} [Monad m] [MonadError m] : MonadError (Reader { getRef := fun _ => getRef, addContext := fun ref msg _ => addContext ref msg } -instance StateRefT.monadError {σ m} [Monad m] [MonadError m] : MonadError (StateRefT σ m) := +instance StateRefT.monadError {ω σ m} [Monad m] [MonadError m] : MonadError (StateRefT' ω σ m) := inferInstanceAs (MonadError (ReaderT _ _)) section Methods diff --git a/src/Lean/MonadEnv.lean b/src/Lean/MonadEnv.lean index 1fb523f710..0d408d3ad0 100644 --- a/src/Lean/MonadEnv.lean +++ b/src/Lean/MonadEnv.lean @@ -21,7 +21,7 @@ def monadEnvFromLift (m) {n} [MonadEnv m] [HasMonadLiftT m n] : MonadEnv n := modifyEnv := fun f => liftM (modifyEnv f : m Unit) } instance ReaderT.monadEnv {m ρ} [Monad m] [MonadEnv m] : MonadEnv (ReaderT ρ m) := monadEnvFromLift m -instance StateRefT.monadEnv {m σ} [MonadEnv m] : MonadEnv (StateRefT σ m) := monadEnvFromLift m +instance StateRefT.monadEnv {ω m σ} [MonadEnv m] : MonadEnv (StateRefT' ω σ m) := monadEnvFromLift m instance OptionT.monadEnv {m} [Monad m] [MonadEnv m] : MonadEnv (OptionT m) := monadEnvFromLift m section Methods diff --git a/src/Lean/Util/Trace.lean b/src/Lean/Util/Trace.lean index 941f5db402..048dec9e1e 100644 --- a/src/Lean/Util/Trace.lean +++ b/src/Lean/Util/Trace.lean @@ -21,7 +21,7 @@ instance ReaderT.monadTracer (ρ : Type) (m : Type → Type) [MonadTracer m] : M trace := fun n x _ => MonadTracer.trace n x, traceM := fun n x ctx => MonadTracer.traceM n (x ctx) } -instance StateRefT.monadTracer (σ : Type) (m : Type → Type) [MonadTracer m] : MonadTracer (StateRefT σ m) := +instance StateRefT.monadTracer (ω σ : Type) (m : Type → Type) [MonadTracer m] : MonadTracer (StateRefT' ω σ m) := inferInstanceAs (MonadTracer (ReaderT _ _)) class MonadTracerAdapter (m : Type → Type) := diff --git a/tests/lean/run/stateRef.lean b/tests/lean/run/stateRef.lean index 09a68aa5dd..c8cdf23aa2 100644 --- a/tests/lean/run/stateRef.lean +++ b/tests/lean/run/stateRef.lean @@ -52,3 +52,15 @@ IO.println $ "state1 " ++ toString a2; pure (a0 + a1 + a2) #eval ((f4.run' ⟨10⟩).run' ⟨20⟩).run' ⟨30⟩ + +abbrev S (ω : Type) := StateRefT Nat $ StateRefT String $ ST ω + +def f5 {ω} : S ω Unit := do +s ← getThe String; +modify fun n => n + s.length; +pure () + +def f5Pure (n : Nat) (s : String) := +runST (fun _ => (f5.run n).run s) + +#eval f5Pure 10 "hello world"