/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Classical
import Init.Control.EState
import Init.Control.Reader
/--
`EST ε σ` is the state-with-exceptions monad `EStateM ε σ`: computations over the
state type `σ` that may throw errors of type `ε`.
It is a `def` rather than an `abbrev`, so instances are not inherited from
`EStateM` automatically and have to be provided explicitly.
-/
def EST (ε : Type) (σ : Type) : Type → Type := EStateM ε σ

/-- `ST σ` is `EST` with error type `Empty`, i.e. a computation that cannot throw. -/
abbrev ST (σ : Type) := EST Empty σ

-- `EST` is a (non-reducible) `def`, so instance resolution does not see through it
-- to `EStateM`; transfer the needed instances explicitly with `inferInstanceAs`.
instance (ε σ : Type) : Monad (EST ε σ) := inferInstanceAs (Monad (EStateM _ _))

instance (ε σ : Type) : MonadExceptOf ε (EST ε σ) := inferInstanceAs (MonadExceptOf ε (EStateM _ _))

-- Requires `Inhabited ε`, as the corresponding `EStateM` instance does.
instance {ε σ : Type} {α : Type} [Inhabited ε] : Inhabited (EST ε σ α) := inferInstanceAs (Inhabited (EStateM _ _ _))

instance (σ : Type) : Monad (ST σ) := inferInstanceAs (Monad (EST _ _))

/--
Auxiliary class for inferring the "state" `σ` of `EST` and `ST` monads.
It carries no data. `σ` is an `outParam`, so instance resolution determines it
from the monad `m`.
-/
class STWorld (σ : outParam Type) (m : Type → Type)

-- The world propagates through monad lifts: if `m` lifts into `n` and `m` has
-- world `σ`, then so does `n`.
instance {σ m n} [MonadLift m n] [STWorld σ m] : STWorld σ n := ⟨⟩

-- `EST ε σ` has world `σ`. Because `ST σ` is the reducible `EST Empty σ`, this
-- instance also covers `ST σ`.
instance {ε σ} : STWorld σ (EST ε σ) := ⟨⟩

/--
Run an `EST` computation that is polymorphic in its state type.

The computation is instantiated at `σ := Unit` and started from the state `()`.
A normal result `a` becomes `Except.ok a`; a thrown error `ex` becomes
`Except.error ex`. The final state is discarded in both cases.
-/
@[noinline, nospecialize]
def runEST {ε α : Type} (x : (σ : Type) → EST ε σ α) : Except ε α :=
  match x Unit () with
  | EStateM.Result.ok a _     => Except.ok a
  | EStateM.Result.error ex _ => Except.error ex

/--
Run an `ST` computation that is polymorphic in its state type and return its result.

The computation is instantiated at `σ := Unit` and started from the state `()`.
The final state is discarded. The error branch is unreachable because `ST` uses
`Empty` as its error type, so `ex : Empty` is eliminated with `nomatch`.
-/
@[noinline, nospecialize]
def runST {α : Type} (x : (σ : Type) → ST σ α) : α :=
  match x Unit () with
  | EStateM.Result.ok a _     => a
  | EStateM.Result.error ex _ => nomatch ex

/--
Lift an `ST σ` computation into `EST ε σ` over the same state type `σ`.
The result and the updated state are passed through unchanged. `ST` cannot throw
(its error type is `Empty`), so the error case is eliminated with `nomatch`.
-/
instance {ε σ} : MonadLift (ST σ) (EST ε σ) := ⟨fun x s =>
  match x s with
  | EStateM.Result.ok a s     => EStateM.Result.ok a s
  | EStateM.Result.error ex _ => nomatch ex⟩

namespace ST

/- References -/

/--
Opaque nonempty type used as the logical representation of a reference handle.
Within this namespace only its nonemptiness (`RefPointed.property`) is used,
to build placeholder handles in the logical models below.
-/
constant RefPointed : NonemptyType.{0}

/--
A reference to a value of type `α`, associated with the `ST`/`EST` state type `σ`.

* `ref` is the opaque handle.
* `h` records that `α` is nonempty. This lets the logical bodies of the primitives
  below produce a value of type `α` (see `Prim.inhabitedFromRef`).
-/
structure Ref (σ : Type) (α : Type) : Type where
  ref : RefPointed.type
  h   : Nonempty α

/-- `Ref σ α` is nonempty whenever `α` is. -/
instance {σ α} [s : Nonempty α] : Nonempty (Ref σ α) :=
  Nonempty.intro { ref := Classical.choice RefPointed.property, h := s }

namespace Prim

/-
Primitive reference operations in `ST σ`.

In compiled code each `@[extern "..."]` constant is implemented by the named
runtime function. The Lean body after `:=` is only a logical model that shows
the result type is inhabited, as `constant` requires. `@&` marks the reference
argument as borrowed.
-/

/--
Auxiliary definition for showing that `ST σ α` is inhabited when we have a `Ref σ α`.
-/
private noncomputable def inhabitedFromRef {σ α} (r : Ref σ α) : ST σ α :=
  -- Turn the `Nonempty α` witness stored in `r` into a local `Inhabited α`
  -- instance, so that `default : α` can be used.
  let inh : Inhabited α := Classical.inhabited_of_nonempty r.h
  pure default

/-- Create a new reference from the initial value `a` (runtime: `lean_st_mk_ref`). -/
@[extern "lean_st_mk_ref"]
constant mkRef {σ α} (a : α) : ST σ (Ref σ α) :=
  pure { ref := Classical.choice RefPointed.property, h := Nonempty.intro a }

/-- Read the value stored in `r` (runtime: `lean_st_ref_get`). -/
@[extern "lean_st_ref_get"]
constant Ref.get {σ α} (r : @& Ref σ α) : ST σ α :=
  inhabitedFromRef r

/-- Store `a` in `r` (runtime: `lean_st_ref_set`). -/
@[extern "lean_st_ref_set"]
constant Ref.set {σ α} (r : @& Ref σ α) (a : α) : ST σ Unit

/--
Store `a` in `r` and return a value of type `α` (runtime: `lean_st_ref_swap`).
NOTE(review): presumably the returned value is the previous contents; confirm
against the runtime implementation.
-/
@[extern "lean_st_ref_swap"]
constant Ref.swap {σ α} (r : @& Ref σ α) (a : α) : ST σ α :=
  inhabitedFromRef r

/--
Take the value out of `r` (runtime: `lean_st_ref_take`). Marked `unsafe`.
NOTE(review): presumably `r` must be refilled with `Ref.set` afterwards, as
`Ref.modifyUnsafe` does; confirm against the runtime implementation.
-/
@[extern "lean_st_ref_take"]
unsafe constant Ref.take {σ α} (r : @& Ref σ α) : ST σ α :=
  inhabitedFromRef r

/--
Compare two references, returning a `Bool` (runtime: `lean_st_ref_ptr_eq`).
NOTE(review): the name suggests identity (pointer) comparison rather than a
comparison of contents; confirm against the runtime implementation.
-/
@[extern "lean_st_ref_ptr_eq"]
constant Ref.ptrEq {σ α} (r1 r2 : @& Ref σ α) : ST σ Bool

/--
Compiled implementation of `Ref.modify`: `take` the value out of `r`, apply `f`,
and store the result back with `set`.
NOTE(review): presumably `take` is used instead of `get` so that `f` receives the
value without `r` still referencing it (allowing in-place updates); confirm.
-/
@[inline] unsafe def Ref.modifyUnsafe {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
  let v ← Ref.take r
  Ref.set r (f v)

/--
Compiled implementation of `Ref.modifyGet`: `take` the value out of `r`, split
`f v` into `(b, a)`, store `a` back into `r`, and return `b`.
-/
@[inline] unsafe def Ref.modifyGetUnsafe {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
  let v ← Ref.take r
  let (b, a) := f v
  Ref.set r a
  pure b

/--
Replace the contents of `r` with `f` applied to them.
The body (`get` then `set`) is the safe reference definition. Compiled code uses
`Ref.modifyUnsafe` via `@[implementedBy]`.
-/
@[implementedBy Ref.modifyUnsafe]
def Ref.modify {σ α : Type} (r : Ref σ α) (f : α → α) : ST σ Unit := do
  let v ← Ref.get r
  Ref.set r (f v)

/--
Apply `f` to the contents of `r`, store the second component of the result back
into `r`, and return the first.
The body (`get` then `set`) is the safe reference definition. Compiled code uses
`Ref.modifyGetUnsafe` via `@[implementedBy]`.
-/
@[implementedBy Ref.modifyGetUnsafe]
def Ref.modifyGet {σ α β : Type} (r : Ref σ α) (f : α → β × α) : ST σ β := do
  let v ← Ref.get r
  let (b, a) := f v
  Ref.set r a
  pure b

end Prim

/-
Monad-polymorphic wrappers: each `Prim` operation is lifted with `liftM` into any
monad `m` into which `ST σ` lifts (`MonadLiftT (ST σ) m`).
-/
section
variable {σ : Type} {m : Type → Type} [Monad m] [MonadLiftT (ST σ) m]

/-- `Prim.mkRef`, lifted into `m`. -/
@[inline] def mkRef {α : Type} (a : α) : m (Ref σ α) := liftM <| Prim.mkRef a
/-- `Prim.Ref.get`, lifted into `m`. -/
@[inline] def Ref.get {α : Type} (r : Ref σ α) : m α := liftM <| Prim.Ref.get r
/-- `Prim.Ref.set`, lifted into `m`. -/
@[inline] def Ref.set {α : Type} (r : Ref σ α) (a : α) : m Unit := liftM <| Prim.Ref.set r a
/-- `Prim.Ref.swap`, lifted into `m`. -/
@[inline] def Ref.swap {α : Type} (r : Ref σ α) (a : α) : m α := liftM <| Prim.Ref.swap r a
/-- `Prim.Ref.take`, lifted into `m` (unsafe, like the primitive). -/
@[inline] unsafe def Ref.take {α : Type} (r : Ref σ α) : m α := liftM <| Prim.Ref.take r
/-- `Prim.Ref.ptrEq`, lifted into `m`. -/
@[inline] def Ref.ptrEq {α : Type} (r1 r2 : Ref σ α) : m Bool := liftM <| Prim.Ref.ptrEq r1 r2
/-- `Prim.Ref.modify`, lifted into `m`. -/
@[inline] def Ref.modify {α : Type} (r : Ref σ α) (f : α → α) : m Unit := liftM <| Prim.Ref.modify r f
/-- `Prim.Ref.modifyGet`, lifted into `m`. -/
@[inline] def Ref.modifyGet {α : Type} {β : Type} (r : Ref σ α) (f : α → β × α) : m β := liftM <| Prim.Ref.modifyGet r f

end

end ST