feat: remove outparam from MonadState

We add helper classes with `outParam`.

@Kha This is similar to the `MonadExceptOf` modification.
Motivation: the new `StateRefT` (state monad implemented using
`IO.Ref`) makes is it quite cheap to have multiple states on the
stack. But, we need a mechanism for accessing the different states in
a convenient way.
Note that, I did not add a `MonadStateOf` class, but helper classes
such as `HasGet` which uses `outParam`. I will do the same for `MonadExcept`.

Summary:
- `get` gets the state on the top of the Monad stack
- `getThe σ` gets the state with type `σ`
- `modify f` modifies the state on the top of the Monad stack.
   We use `modify fun s => { s with ... }` quite often, and we cannot
   infer type of `s` here.
- `modifyThe σ f` allows us to select which state on the stack we are modifying.
- I didn't add `setThe`, since we usually can infer the state type at
  `set s`. In the whole codebase, we have only one instance where this
  is not true.
This commit is contained in:
Leonardo de Moura 2020-08-18 14:19:34 -07:00
parent f01d45a6c1
commit 5605735137
4 changed files with 49 additions and 12 deletions

View file

@ -90,7 +90,7 @@ end StateT
/-- An implementation of [MonadState](https://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-State-Class.html).
In contrast to the Haskell implementation, we use overlapping instances to derive instances
automatically from `monadLift`. -/
class MonadState (σ : outParam (Type u)) (m : Type u → Type v) :=
class MonadState (σ : Type u) (m : Type u → Type v) :=
/- Obtain the top-most State of a Monad stack. -/
(get : m σ)
/- Set the top-most State of a Monad stack. -/
@ -101,16 +101,42 @@ class MonadState (σ : outParam (Type u)) (m : Type u → Type v) :=
because the latter does not use the State linearly (without sufficient inlining). -/
(modifyGet {α : Type u} : (σα × σ) → m α)
export MonadState (get set modifyGet)
export MonadState (set)
abbrev getThe (σ : Type u) {m : Type u → Type v} [MonadState σ m] : m σ :=
MonadState.get
@[inline] abbrev modifyThe (σ : Type u) {m : Type u → Type v} [MonadState σ m] (f : σσ) : m PUnit :=
MonadState.modifyGet fun s => (PUnit.unit, f s)
/-- Class for `get`, but `σ` is an outParam for convenience -/
class HasGet (σ : outParam (Type u)) (m : Type u → Type v) :=
(get : m σ)
export HasGet (get)
instance monadState.hasGet (σ : Type u) (m : Type u → Type v) [MonadState σ m] : HasGet σ m :=
{ get := getThe σ }
class HasModifyGet (σ : outParam (Type u)) (m : Type u → Type v) :=
(modifyGet {α : Type u} : (σα × σ) → m α)
export HasModifyGet (modifyGet)
def MonadState.toModifyGet {σ : Type u} {m : Type u → Type v} (s : MonadState σ m) : HasModifyGet σ m :=
{ modifyGet := s.modifyGet }
instance monadState.hasModifyGet (σ : Type u) (m : Type u → Type v) [MonadState σ m] : HasModifyGet σ m :=
{ modifyGet := fun α f => MonadState.modifyGet f }
section
variables {σ : Type u} {m : Type u → Type v}
@[inline] def modify [MonadState σ m] (f : σσ) : m PUnit :=
modifyGet (fun s => (PUnit.unit, f s))
@[inline] def modify [HasModifyGet σ m] (f : σσ) : m PUnit :=
modifyGet fun s => (PUnit.unit, f s)
@[inline] def getModify [MonadState σ m] [Monad m] (f : σσ) : m σ := do
s ← get; modify f; pure s
@[inline] def getModify [HasModifyGet σ m] [Monad m] (f : σσ) : m σ := do
modifyGet fun s => (s, f s)
-- NOTE: The Ordering of the following two instances determines that the top-most `StateT` Monad layer
-- will be picked first

View file

@ -191,7 +191,7 @@ stx ← getCur;
idx ← getIdx;
st ← get;
-- reset prec/prec and store `mkParen` for the recursive call
set { stxTrav := st.stxTrav };
set { stxTrav := st.stxTrav : State };
trace! `PrettyPrinter.parenthesize ("parenthesizing (contPrec := " ++ toString st.contPrec ++ ")" ++ MessageData.nest 2 (line ++ stx));
adaptReader (fun (ctx : Context) => { ctx with mkParen := some mkParen }) x;
{ minPrec := some minPrec, trailPrec := trailPrec, .. } ← get

View file

@ -377,11 +377,11 @@ namespace MonadTraverser
variables {m : Type → Type} [Monad m] [t : MonadTraverser m]
def getCur : m Syntax := Traverser.cur <$> t.st.get
def setCur (stx : Syntax) : m Unit := @modify _ _ t.st (fun t => t.setCur stx)
def goDown (idx : Nat) : m Unit := @modify _ _ t.st (fun t => t.down idx)
def goUp : m Unit := @modify _ _ t.st (fun t => t.up)
def goLeft : m Unit := @modify _ _ t.st (fun t => t.left)
def goRight : m Unit := @modify _ _ t.st (fun t => t.right)
def setCur (stx : Syntax) : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.setCur stx)
def goDown (idx : Nat) : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.down idx)
def goUp : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.up)
def goLeft : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.left)
def goRight : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.right)
def getIdx : m Nat := do
st ← t.st.get;

View file

@ -17,3 +17,14 @@ modify fun s => s + v;
get
#eval (f2.run 10).run' 20
def f3 : StateT String (StateRefT Nat IO) Nat := do
s ← get;
n ← getThe Nat;
set (s ++ ", " ++ toString n);
s ← get;
IO.println s;
set (n+1);
getThe Nat
#eval (f3.run' "test").run' 10