lean4-htt/src/Init/System/ST.lean
2020-08-25 13:54:41 -07:00

107 lines
4.2 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Control.EState
import Init.Control.Reader
-- `EST ε σ` is the `EStateM` monad with error type `ε` and state type `σ`,
-- made available under its own name.
def EST (ε σ : Type) : Type → Type :=
  EStateM ε σ
-- `ST σ` is `EST` with `Empty` as its error type. `Empty` has no values, which is
-- why consumers such as `runST` can discharge the error case with `Empty.rec`.
abbrev ST (σ : Type) : Type → Type :=
  EST Empty σ
-- `EST ε σ` is defined as `EStateM ε σ`, so each instance below is the matching
-- `EStateM` instance, obtained through `inferInstanceAs`.
instance (ε σ : Type) : Monad (EST ε σ) :=
  inferInstanceAs (Monad (EStateM ε σ))

instance (ε σ : Type) : MonadExceptOf ε (EST ε σ) :=
  inferInstanceAs (MonadExceptOf ε (EStateM ε σ))

instance {ε σ : Type} {α : Type} [Inhabited ε] : Inhabited (EST ε σ α) :=
  inferInstanceAs (Inhabited (EStateM ε σ α))

-- `ST σ` abbreviates `EST Empty σ`; reuse the `EST` monad instance at that type.
instance (σ : Type) : Monad (ST σ) :=
  inferInstanceAs (Monad (EST Empty σ))
-- Auxiliary class for inferring the "state" of `EST` and `ST` monads.
-- `σ` is declared as an `outParam`, so it is an output of instance resolution
-- determined from the monad `m`. The class declares no fields; an instance only
-- records the association between `m` and `σ`.
class STWorld (σ : outParam Type) (m : Type → Type)
-- Propagation step: if `m` has state `σ` and there is a `MonadLift m n`,
-- then `n` has state `σ` as well.
instance STWorld.trans {σ m n} [STWorld σ m] [MonadLift m n] : STWorld σ n := ⟨⟩
-- Base case: the state of `EST ε σ` is `σ`.
instance STWorld.base {ε σ} : STWorld σ (EST ε σ) := ⟨⟩
-- Runs an `EST` computation that is polymorphic in its state type `σ`.
-- The computation is instantiated at `σ := Unit` and started from the state `()`.
-- The final state is discarded and the outcome is repackaged as an `Except`:
-- a successful result becomes `Except.ok`, a raised error becomes `Except.error`.
def runEST {ε α : Type} (x : forall (σ : Type), EST ε σ α) : Except ε α :=
  match x Unit () with
  | EStateM.Result.error err _ => Except.error err
  | EStateM.Result.ok val _    => Except.ok val
-- Runs an `ST` computation that is polymorphic in its state type `σ` and returns
-- its result. The computation is instantiated at `σ := Unit` and started from `()`;
-- the final state is discarded.
-- The error branch carries a value of type `Empty`, so it is eliminated with `Empty.rec`.
def runST {α : Type} (x : forall (σ : Type), ST σ α) : α :=
  match x Unit () with
  | EStateM.Result.error bad _ => Empty.rec _ bad
  | EStateM.Result.ok val _    => val
-- Lifts an `ST σ` action into `EST ε σ` for any error type `ε`.
-- The action is run on the incoming state; a successful result and its state are
-- passed through unchanged. The error branch carries a value of type `Empty`
-- and is eliminated with `Empty.rec`.
instance st2est {ε σ} : MonadLift (ST σ) (EST ε σ) :=
  ⟨fun α act st => match act st with
    | EStateM.Result.ok val st'  => EStateM.Result.ok val st'
    | EStateM.Result.error bad _ => Empty.rec _ bad⟩
namespace ST
/- References -/
-- Opaque pointed type at universe level 0. Its carrier `RefPointed.type` is the type of
-- the `ref` field of `Ref`, and its distinguished element `RefPointed.val` is the
-- placeholder reference used by the Lean-side definitions in this namespace.
constant RefPointed : PointedType.{0} := arbitrary _
-- A reference parameterised by a state type `σ` and a content type `α`.
-- Neither field mentions `σ`; it occurs only as a type parameter.
--   * `ref` : the underlying reference, a value of the opaque `RefPointed.type`
--   * `h`   : evidence that `α` is nonempty (consumed by `Prim.inhabitedFromRef`)
structure Ref (σ : Type) (α : Type) : Type :=
(ref : RefPointed.type) (h : Nonempty α)
-- Default `Ref` for an inhabited `α`: the distinguished `RefPointed.val`, with the
-- nonemptiness evidence built from the default element of `α`.
instance Ref.inhabited {σ α} [Inhabited α] : Inhabited (Ref σ α) :=
  ⟨{ ref := RefPointed.val, h := Nonempty.intro (arbitrary α) }⟩
namespace Prim
/- Auxiliary definition for showing that `ST σ α` is inhabited when we have a `Ref σ α`.
   The `Nonempty α` evidence stored in `r.h` is converted into an `Inhabited α` instance
   with `Classical.inhabitedOfNonempty`, and that instance's default element is returned.
   The definition is marked `noncomputable`; it serves as the Lean-side body of the
   `extern` constants below. -/
private noncomputable def inhabitedFromRef {σ α} (r : Ref σ α) : ST σ α :=
  pure ((Classical.inhabitedOfNonempty r.h).default)
-- Bound to the external symbol `lean_st_mk_ref`.
-- The Lean-side body wraps `RefPointed.val` together with a `Nonempty α` proof built from `a`.
-- NOTE(review): presumably the external implementation creates a fresh reference
-- holding `a` — confirm against the runtime.
@[extern "lean_st_mk_ref"]
constant mkRef {σ α} (a : α) : ST σ (Ref σ α) := pure { ref := RefPointed.val, h := Nonempty.intro a }
-- Bound to the external symbol `lean_st_ref_get`; `r` is passed as a borrowed (`@&`) argument.
-- The Lean-side body is `inhabitedFromRef r`.
-- NOTE(review): presumably returns the value currently stored in `r` — confirm against the runtime.
@[extern "lean_st_ref_get"]
constant Ref.get {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r
-- Bound to the external symbol `lean_st_ref_set`; `r` is passed as a borrowed (`@&`) argument.
-- The Lean-side body is `arbitrary _`.
-- NOTE(review): presumably stores `a` in `r` — confirm against the runtime.
@[extern "lean_st_ref_set"]
constant Ref.set {σ α} (r : @& Ref σ α) (a : α) : ST σ Unit := arbitrary _
-- Bound to the external symbol `lean_st_ref_swap`; `r` is passed as a borrowed (`@&`) argument.
-- The Lean-side body is `inhabitedFromRef r`.
-- NOTE(review): presumably stores `a` in `r` and returns the previous contents — confirm
-- against the runtime.
@[extern "lean_st_ref_swap"]
constant Ref.swap {σ α} (r : @& Ref σ α) (a : α) : ST σ α := inhabitedFromRef r
-- Bound to the external symbol `lean_st_ref_take`; `r` is passed as a borrowed (`@&`) argument.
-- Declared `unsafe`. Within this file it is used by `Ref.modifyUnsafe` and
-- `Ref.modifyGetUnsafe`, each of which follows it with a `Ref.set` on the same reference.
-- NOTE(review): presumably removes the value from `r` without leaving a valid value
-- behind, which would explain the `unsafe` marker — confirm against the runtime.
@[extern "lean_st_ref_take"]
unsafe constant Ref.take {σ α} (r : @& Ref σ α) : ST σ α := inhabitedFromRef r
-- Bound to the external symbol `lean_st_ref_ptr_eq`; both references are passed as
-- borrowed (`@&`) arguments. The Lean-side body is `arbitrary _`.
-- NOTE(review): presumably reports whether `r1` and `r2` are the same underlying
-- reference — confirm against the runtime.
@[extern "lean_st_ref_ptr_eq"]
constant Ref.ptrEq {σ α} (r1 r2 : @& Ref σ α) : ST σ Bool := arbitrary _
-- Replaces the contents of `r` with `f` applied to them.
-- The current value is obtained with the `unsafe` primitive `Ref.take` (rather than
-- `Ref.get`) and the updated value is written back with `Ref.set`.
-- This definition is the `implementedBy` target of `Ref.modify`.
-- `f` must have type `α → α`: it is applied to the taken value and its result is
-- stored back into the same reference.
@[inline] unsafe def Ref.modifyUnsafe {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
  v ← Ref.take r;
  Ref.set r (f v)
-- Updates the contents of `r` with `f` and returns the auxiliary result `f` produces.
-- `f` maps the current value to a pair: the first component is returned to the caller,
-- the second component is written back into `r`.
-- The current value is obtained with the `unsafe` primitive `Ref.take`.
-- This definition is the `implementedBy` target of `Ref.modifyGet`.
@[inline] unsafe def Ref.modifyGetUnsafe {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
  prev ← Ref.take r;
  let (result, updated) := f prev;
  Ref.set r updated;
  pure result
-- Replaces the contents of `r` with `f` applied to them: the current value is read with
-- `Ref.get` and `f v` is written back with `Ref.set`.
-- The `implementedBy` attribute names `Ref.modifyUnsafe` as the implementation used for
-- compiled code; this definition is the reference version.
-- `f` must have type `α → α`: it is applied to the value read and its result is stored
-- back into the same reference.
@[implementedBy Ref.modifyUnsafe]
def Ref.modify {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
  v ← Ref.get r;
  Ref.set r (f v)
-- Updates the contents of `r` with `f` and returns the auxiliary result `f` produces.
-- The current value is read with `Ref.get`; `f` maps it to a pair whose first component
-- is returned to the caller and whose second component is written back with `Ref.set`.
-- The `implementedBy` attribute names `Ref.modifyGetUnsafe` as the implementation used
-- for compiled code; this definition is the reference version.
@[implementedBy Ref.modifyGetUnsafe]
def Ref.modifyGet {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
  prev ← Ref.get r;
  let (result, updated) := f prev;
  Ref.set r updated;
  pure result
end Prim
section
-- `σ` is the state type of the references; `m` is any monad into which `ST σ`
-- computations can be lifted (`MonadLiftT (ST σ) m`).
variables {σ : Type} {m : Type → Type} [Monad m] [MonadLiftT (ST σ) m]
-- Each definition below is the same-named `Prim` operation lifted from `ST σ` into `m`
-- with `liftM`. `Ref.take` stays `unsafe` because the primitive it lifts is `unsafe`.
@[inline] def mkRef {α : Type} (a : α) : m (Ref σ α) := liftM $ Prim.mkRef a
@[inline] def Ref.get {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.get r
@[inline] def Ref.set {α : Type} (r : Ref σ α) (a : α) : m Unit := liftM $ Prim.Ref.set r a
@[inline] def Ref.swap {α : Type} (r : Ref σ α) (a : α) : m α := liftM $ Prim.Ref.swap r a
@[inline] unsafe def Ref.take {α : Type} (r : Ref σ α) : m α := liftM $ Prim.Ref.take r
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref σ α) : m Bool := liftM $ Prim.Ref.ptrEq r1 r2
-- `f` has type `α → α`, as required by `Prim.Ref.modify`, which applies it to the stored value.
@[inline] def Ref.modify {α : Type} (r : Ref σ α) (f : α → α) : m Unit := liftM $ Prim.Ref.modify r f
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref σ α) (f : α → β × α) : m β := liftM $ Prim.Ref.modifyGet r f
end
end ST