107 lines
4.2 KiB
Text
107 lines
4.2 KiB
Text
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
|
||
prelude
|
||
import Init.Control.EState
|
||
import Init.Control.Reader
|
||
|
||
-- `EST ε σ` is `EStateM ε σ` under a separate name: a computation over a state of
-- type `σ` whose result is either a value or an error of type `ε`.
def EST (ε σ : Type) : Type → Type := EStateM ε σ
|
||
-- `ST σ` is `EST` specialised to the error type `Empty`. Since `Empty` has no
-- values, the error case of an `ST` computation can never actually occur.
abbrev ST (σ : Type) : Type → Type := EST Empty σ
|
||
|
||
-- `EST ε σ` is defined as `EStateM ε σ`, so the instances below are obtained
-- directly from the corresponding `EStateM` instances.
instance (ε σ : Type) : Monad (EST ε σ) :=
  inferInstanceAs (Monad (EStateM ε σ))

instance (ε σ : Type) : MonadExceptOf ε (EST ε σ) :=
  inferInstanceAs (MonadExceptOf ε (EStateM ε σ))

instance {ε σ α : Type} [Inhabited ε] : Inhabited (EST ε σ α) :=
  inferInstanceAs (Inhabited (EStateM ε σ α))

-- `ST σ` unfolds to `EST Empty σ`, so the `EST` monad instance is reused.
instance (σ : Type) : Monad (ST σ) :=
  inferInstanceAs (Monad (EST Empty σ))
|
||
|
||
-- Auxiliary class for inferring the "state" of `EST` and `ST` monads
|
||
-- The class has no fields; it only relates a monad `m` to a type `σ`.
-- `σ` is an `outParam`, so instance resolution computes it from `m`.
class STWorld (σ : outParam Type) (m : Type → Type)
|
||
|
||
-- Transitive case: if `m` is associated with `σ` and `m` lifts into `n`
-- (`MonadLift m n`), then `n` is associated with the same `σ`.
instance STWorld.trans {σ m n} [STWorld σ m] [MonadLift m n] : STWorld σ n := ⟨⟩
|
||
-- Base case: `EST ε σ` is associated with `σ`, for every error type `ε`.
instance STWorld.base {ε σ} : STWorld σ (EST ε σ) := ⟨⟩
|
||
|
||
/- Runs an `EST` computation that is polymorphic in its state type `σ` and returns
   the outcome as an `Except ε α`. The computation is instantiated at `σ := Unit`
   and started from the state `()`; the final state is discarded. -/
def runEST {ε α : Type} (x : forall (σ : Type), EST ε σ α) : Except ε α :=
  match x Unit () with
  | EStateM.Result.error err _ => Except.error err
  | EStateM.Result.ok val _    => Except.ok val
|
||
|
||
/- Runs an `ST` computation that is polymorphic in its state type `σ` and returns
   its result. The computation is instantiated at `σ := Unit` and started from the
   state `()`. The error case carries a value of type `Empty`, so it is discharged
   with `Empty.rec`. -/
def runST {α : Type} (x : forall (σ : Type), ST σ α) : α :=
  match x Unit () with
  | EStateM.Result.error impossible _ => Empty.rec _ impossible
  | EStateM.Result.ok val _           => val
|
||
|
||
/- Lifts an `ST σ` action into `EST ε σ` for an arbitrary error type `ε`.
   A successful result is passed through together with its state. The error case
   carries a value of type `Empty`, so it is eliminated with `Empty.rec`. -/
instance st2est {ε σ} : MonadLift (ST σ) (EST ε σ) :=
  ⟨fun α act st => match act st with
    | EStateM.Result.ok val st'         => EStateM.Result.ok val st'
    | EStateM.Result.error impossible _ => Empty.rec _ impossible⟩
|
||
|
||
namespace ST
|
||
|
||
/- References -/
|
||
-- Opaque pointed type in universe 0. `RefPointed.type` is the type of the `ref`
-- field of `Ref`, and `RefPointed.val` is the value stored in that field by the
-- Lean-level default definitions in this file.
constant RefPointed : PointedType.{0} := arbitrary _
|
||
|
||
/- A reference to a value of type `α`, tagged with the state type `σ`.
   `σ` does not occur in any field; it appears only in the type.
   * `ref` is the opaque handle, of type `RefPointed.type`.
   * `h` records that `α` is nonempty. -/
structure Ref (σ α : Type) : Type :=
  (ref : RefPointed.type)
  (h   : Nonempty α)
|
||
|
||
-- Default `Ref`: the handle is `RefPointed.val`, and the nonemptiness proof is
-- built from the default element of `α`.
instance Ref.inhabited {σ α} [Inhabited α] : Inhabited (Ref σ α) :=
  ⟨{ ref := RefPointed.val, h := Nonempty.intro (arbitrary α) }⟩
|
||
|
||
namespace Prim
|
||
|
||
/- Auxiliary definition showing that `ST σ α` is inhabited whenever a `Ref σ α` is
   available: the `Nonempty α` proof stored in the reference is turned into an
   element of `α` with `Classical.inhabitedOfNonempty` (hence `noncomputable`),
   and that element is returned with `pure`. -/
private noncomputable def inhabitedFromRef {σ α} (r : Ref σ α) : ST σ α :=
  pure (Classical.inhabitedOfNonempty r.h).default
|
||
|
||
-- Primitive bound to the external symbol `lean_st_mk_ref`.
-- NOTE(review): by its name this allocates a new reference initialised with `a`;
-- confirm against the runtime implementation.
-- The Lean term below is only the default value of the opaque constant.
@[extern "lean_st_mk_ref"]
constant mkRef {σ α : Type} (a : α) : ST σ (Ref σ α) :=
  pure { h := Nonempty.intro a, ref := RefPointed.val }
|
||
-- Primitive bound to the external symbol `lean_st_ref_get`; `r` is passed as a
-- borrowed argument (`@&`).
-- NOTE(review): by its name this reads the current contents of `r`; confirm
-- against the runtime implementation.
-- The Lean term below is only the default value of the opaque constant.
@[extern "lean_st_ref_get"]
constant Ref.get {σ α : Type} (r : @& Ref σ α) : ST σ α :=
  inhabitedFromRef r
|
||
-- Primitive bound to the external symbol `lean_st_ref_set`; `r` is passed as a
-- borrowed argument (`@&`).
-- NOTE(review): by its name this overwrites the contents of `r` with `a`; confirm
-- against the runtime implementation.
-- The Lean term below is only the default value of the opaque constant.
@[extern "lean_st_ref_set"]
constant Ref.set {σ α : Type} (r : @& Ref σ α) (a : α) : ST σ Unit :=
  arbitrary (ST σ Unit)
|
||
-- Primitive bound to the external symbol `lean_st_ref_swap`; `r` is passed as a
-- borrowed argument (`@&`).
-- NOTE(review): by its name this stores `a` in `r` and returns the previous
-- contents; confirm against the runtime implementation.
-- The Lean term below is only the default value of the opaque constant.
@[extern "lean_st_ref_swap"]
constant Ref.swap {σ α : Type} (r : @& Ref σ α) (a : α) : ST σ α :=
  inhabitedFromRef r
|
||
-- Primitive bound to the external symbol `lean_st_ref_take`; `r` is passed as a
-- borrowed argument (`@&`). Declared `unsafe`.
-- NOTE(review): by its name this moves the contents out of `r`, leaving the
-- reference without a valid value until the next `Ref.set`; confirm against the
-- runtime implementation.
-- The Lean term below is only the default value of the opaque constant.
@[extern "lean_st_ref_take"]
unsafe constant Ref.take {σ α : Type} (r : @& Ref σ α) : ST σ α :=
  inhabitedFromRef r
|
||
-- Primitive bound to the external symbol `lean_st_ref_ptr_eq`; both references
-- are passed as borrowed arguments (`@&`).
-- NOTE(review): by its name this tests whether `r1` and `r2` are the same
-- reference (pointer equality); confirm against the runtime implementation.
-- The Lean term below is only the default value of the opaque constant.
@[extern "lean_st_ref_ptr_eq"]
constant Ref.ptrEq {σ α : Type} (r1 r2 : @& Ref σ α) : ST σ Bool :=
  arbitrary (ST σ Bool)
|
||
|
||
-- Updates the contents of `r` with `f`: obtains the current value with the unsafe
-- primitive `Ref.take`, then writes `f` applied to it back with `Ref.set`.
-- NOTE(review): presumably `take` is used instead of `get` so that `f` receives
-- the value without the reference still holding on to it; confirm.
@[inline] unsafe def Ref.modifyUnsafe {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
  old ← Ref.take r;
  Ref.set r (f old)
|
||
|
||
-- Updates the contents of `r` and returns an extra result. The current value is
-- obtained with the unsafe primitive `Ref.take`; `f` maps it to a pair whose
-- second component is written back with `Ref.set` and whose first component is
-- returned.
@[inline] unsafe def Ref.modifyGetUnsafe {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
  old ← Ref.take r;
  let (out, next) := f old;
  Ref.set r next;
  pure out
|
||
|
||
-- Updates the contents of `r` with `f`: reads the current value with `Ref.get`
-- and writes `f` applied to it back with `Ref.set`.
-- This is the reference definition; compiled code uses `Ref.modifyUnsafe`
-- (see the `implementedBy` attribute).
@[implementedBy Ref.modifyUnsafe]
def Ref.modify {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
  cur ← Ref.get r;
  Ref.set r (f cur)
|
||
|
||
-- Updates the contents of `r` and returns an extra result. The current value is
-- read with `Ref.get`; `f` maps it to a pair whose second component is written
-- back with `Ref.set` and whose first component is returned.
-- This is the reference definition; compiled code uses `Ref.modifyGetUnsafe`
-- (see the `implementedBy` attribute).
@[implementedBy Ref.modifyGetUnsafe]
def Ref.modifyGet {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
  cur ← Ref.get r;
  let (out, next) := f cur;
  Ref.set r next;
  pure out
|
||
|
||
end Prim
|
||
|
||
section
-- Every definition in this section works in any monad `m` into which `ST σ` can
-- be lifted (`MonadLiftT (ST σ) m`). Each one is the `Prim` operation of the same
-- name, lifted into `m` with `liftM`.
variables {σ : Type} {m : Type → Type} [Monad m] [MonadLiftT (ST σ) m]

@[inline] def mkRef {α : Type} (a : α) : m (Ref σ α) := liftM (Prim.mkRef a)
@[inline] def Ref.get {α : Type} (r : Ref σ α) : m α := liftM (Prim.Ref.get r)
@[inline] def Ref.set {α : Type} (r : Ref σ α) (a : α) : m Unit := liftM (Prim.Ref.set r a)
@[inline] def Ref.swap {α : Type} (r : Ref σ α) (a : α) : m α := liftM (Prim.Ref.swap r a)
-- `unsafe` because the underlying `Prim.Ref.take` is `unsafe`.
@[inline] unsafe def Ref.take {α : Type} (r : Ref σ α) : m α := liftM (Prim.Ref.take r)
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref σ α) : m Bool := liftM (Prim.Ref.ptrEq r1 r2)
@[inline] def Ref.modify {α : Type} (r : Ref σ α) (f : α → α) : m Unit := liftM (Prim.Ref.modify r f)
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref σ α) (f : α → β × α) : m β := liftM (Prim.Ref.modifyGet r f)

end
|
||
|
||
end ST
|