/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Control.EState
import Init.Control.Reader

/- `EST ε σ` is `EStateM ε σ` under its own name: results are either
   `EStateM.Result.ok` (a value plus a state of type `σ`) or
   `EStateM.Result.error` (an error of type `ε` plus a state). -/
def EST (ε σ : Type) : Type → Type :=
  EStateM ε σ

/- `ST σ` is `EST` with the error type fixed to `Empty`, so the error case
   can always be eliminated with `Empty.rec` (see `runST` and `st2est`). -/
abbrev ST (σ : Type) :=
  EST Empty σ

-- The instances below are the `EStateM` ones, re-exposed because `EST` is a plain `def`.
instance (ε σ : Type) : Monad (EST ε σ) :=
  inferInstanceAs (Monad (EStateM _ _))

instance (ε σ : Type) : MonadExceptOf ε (EST ε σ) :=
  inferInstanceAs (MonadExceptOf ε (EStateM _ _))

instance {ε σ : Type} {α : Type} [Inhabited ε] : Inhabited (EST ε σ α) :=
  inferInstanceAs (Inhabited (EStateM _ _ _))

instance (σ : Type) : Monad (ST σ) :=
  inferInstanceAs (Monad (EST _ _))

-- Auxiliary class for inferring the "state" of `EST` and `ST` monads.
-- `σ` is an `outParam`: it is determined from the monad `m`.
class STWorld (σ : outParam Type) (m : Type → Type)

-- If `m` has world `σ` and `m` lifts into `n`, then `n` has world `σ` too.
instance STWorld.trans {σ m n} [STWorld σ m] [MonadLift m n] : STWorld σ n :=
  ⟨⟩

-- Base case: the world of `EST ε σ` is `σ`.
instance STWorld.base {ε σ} : STWorld σ (EST ε σ) :=
  ⟨⟩

/- Runs an `EST` computation that is polymorphic in its state type.
   The state type is instantiated to `Unit` and the initial state is `()`;
   the final state is discarded and the outcome is returned as an `Except`. -/
def runEST {ε α : Type} (x : ∀ (σ : Type), EST ε σ α) : Except ε α :=
  match x Unit () with
  | EStateM.Result.ok v _    => Except.ok v
  | EStateM.Result.error e _ => Except.error e

/- Runs an `ST` computation that is polymorphic in its state type and returns its value.
   As in `runEST`, the state type is `Unit` and the initial state is `()`.
   The error branch carries a value of type `Empty` and is closed with `Empty.rec`. -/
def runST {α : Type} (x : ∀ (σ : Type), ST σ α) : α :=
  match x Unit () with
  | EStateM.Result.ok v _    => v
  | EStateM.Result.error e _ => Empty.rec _ e

/- Lifts `ST σ` into `EST ε σ` for any error type `ε`.
   An `ok` result is rebuilt unchanged; the `error` branch holds an `Empty`
   value and is closed with `Empty.rec`. -/
instance st2est {ε σ} : MonadLift (ST σ) (EST ε σ) :=
  ⟨fun α act st =>
    match act st with
    | EStateM.Result.ok v st'  => EStateM.Result.ok v st'
    | EStateM.Result.error e _ => Empty.rec _ e⟩

namespace ST

/- References -/

-- Opaque pointed type; its carrier is the type of the `ref` field of `Ref`.
constant RefPointed : PointedType.{0} := arbitrary _

/- A reference holding a value of type `α`, tagged with the state type `σ`.
   `σ` does not occur in the fields. `h` records that `α` is nonempty, which is
   what `Prim.inhabitedFromRef` uses to build an inhabitant of `ST σ α`. -/
structure Ref (σ : Type) (α : Type) : Type :=
  (ref : RefPointed.type)
  (h   : Nonempty α)

instance Ref.inhabited {σ α} [Inhabited α] : Inhabited (Ref σ α) :=
  ⟨{ ref := RefPointed.val, h := Nonempty.intro (arbitrary _) }⟩

namespace Prim

/- Auxiliary definition for showing that `ST σ α` is inhabited when we have a `Ref σ α` -/
private noncomputable def inhabitedFromRef {σ α} (r : Ref σ α) : ST σ α :=
  pure ((Classical.inhabitedOfNonempty r.h).default)

/- The constants below are tagged `@[extern "..."]` with the symbol named in the attribute.
   The Lean terms after `:=` are the values supplied for these `constant` declarations
   (`pure ...`, `arbitrary _`, or `inhabitedFromRef r`).
   NOTE(review): the runtime behaviour of each `lean_st_*` symbol is not visible
   in this file — confirm against the runtime sources. -/

@[extern "lean_st_mk_ref"]
constant mkRef {σ α} (a : α) : ST σ (Ref σ α) :=
  pure { ref := RefPointed.val, h := Nonempty.intro a }

@[extern "lean_st_ref_get"]
constant Ref.get {σ α} (r : @& Ref σ α) : ST σ α :=
  inhabitedFromRef r

@[extern "lean_st_ref_set"]
constant Ref.set {σ α} (r : @& Ref σ α) (a : α) : ST σ Unit :=
  arbitrary _

@[extern "lean_st_ref_swap"]
constant Ref.swap {σ α} (r : @& Ref σ α) (a : α) : ST σ α :=
  inhabitedFromRef r

-- Marked `unsafe`; it is only used by the `*Unsafe` helpers below.
@[extern "lean_st_ref_take"]
unsafe constant Ref.take {σ α} (r : @& Ref σ α) : ST σ α :=
  inhabitedFromRef r

@[extern "lean_st_ref_ptr_eq"]
constant Ref.ptrEq {σ α} (r1 r2 : @& Ref σ α) : ST σ Bool :=
  arbitrary _

-- Implementation of `Ref.modify` (via `implementedBy`): `take` the value, then `set` `f` of it.
@[inline] unsafe def Ref.modifyUnsafe {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit :=
  do cur ← Ref.take r;
     Ref.set r (f cur)

-- Implementation of `Ref.modifyGet` (via `implementedBy`): `take` the value, apply `f`,
-- `set` the second component and return the first.
@[inline] unsafe def Ref.modifyGetUnsafe {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β :=
  do cur ← Ref.take r;
     let (ret, upd) := f cur;
     Ref.set r upd;
     pure ret

-- Reads `r` with `Ref.get`, then writes back `f` applied to the value read.
-- Tagged `implementedBy Ref.modifyUnsafe`.
@[implementedBy Ref.modifyUnsafe]
def Ref.modify {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit :=
  do cur ← Ref.get r;
     Ref.set r (f cur)

-- Reads `r` with `Ref.get`, applies `f`, writes back the second component of the
-- result and returns the first. Tagged `implementedBy Ref.modifyGetUnsafe`.
@[implementedBy Ref.modifyGetUnsafe]
def Ref.modifyGet {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β :=
  do cur ← Ref.get r;
     let (ret, upd) := f cur;
     Ref.set r upd;
     pure ret

end Prim

section

-- Each definition in this section is the matching `Prim` operation passed through `liftM`,
-- for any monad `m` with a `MonadLiftT (ST σ) m` instance.
variables {σ : Type} {m : Type → Type} [Monad m] [MonadLiftT (ST σ) m]

@[inline] def mkRef {α : Type} (a : α) : m (Ref σ α) :=
  liftM (Prim.mkRef a)

@[inline] def Ref.get {α : Type} (r : Ref σ α) : m α :=
  liftM (Prim.Ref.get r)

@[inline] def Ref.set {α : Type} (r : Ref σ α) (a : α) : m Unit :=
  liftM (Prim.Ref.set r a)

@[inline] def Ref.swap {α : Type} (r : Ref σ α) (a : α) : m α :=
  liftM (Prim.Ref.swap r a)

@[inline] unsafe def Ref.take {α : Type} (r : Ref σ α) : m α :=
  liftM (Prim.Ref.take r)

@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref σ α) : m Bool :=
  liftM (Prim.Ref.ptrEq r1 r2)

@[inline] def Ref.modify {α : Type} (r : Ref σ α) (f : α → α) : m Unit :=
  liftM (Prim.Ref.modify r f)

@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref σ α) (f : α → β × α) : m β :=
  liftM (Prim.Ref.modifyGet r f)

end

end ST