feat: remove outparam from MonadState
We add helper classes with `outParam`.
@Kha This is similar to the `MonadExceptOf` modification.
Motivation: the new `StateRefT` (state monad implemented using
`IO.Ref`) makes is it quite cheap to have multiple states on the
stack. But, we need a mechanism for accessing the different states in
a convenient way.
Note that, I did not add a `MonadStateOf` class, but helper classes
such as `HasGet` which uses `outParam`. I will do the same for `MonadExcept`.
Summary:
- `get` gets the state on the top of the Monad stack
- `getThe σ` gets the state with type `σ`
- `modify f` modifies the state on the top of the Monad stack.
We use `modify fun s => { s with ... }` quite often, and we cannot
infer type of `s` here.
- `modifyThe σ f` allows us to select which state on the stack we are modifying.
- I didn't add `setThe`, since we usually can infer the state type at
`set s`. In the whole codebase, we have only one instance where this
is not true.
This commit is contained in:
parent
f01d45a6c1
commit
5605735137
4 changed files with 49 additions and 12 deletions
|
|
@ -90,7 +90,7 @@ end StateT
|
|||
/-- An implementation of [MonadState](https://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-State-Class.html).
|
||||
In contrast to the Haskell implementation, we use overlapping instances to derive instances
|
||||
automatically from `monadLift`. -/
|
||||
class MonadState (σ : outParam (Type u)) (m : Type u → Type v) :=
|
||||
class MonadState (σ : Type u) (m : Type u → Type v) :=
|
||||
/- Obtain the top-most State of a Monad stack. -/
|
||||
(get : m σ)
|
||||
/- Set the top-most State of a Monad stack. -/
|
||||
|
|
@ -101,16 +101,42 @@ class MonadState (σ : outParam (Type u)) (m : Type u → Type v) :=
|
|||
because the latter does not use the State linearly (without sufficient inlining). -/
|
||||
(modifyGet {α : Type u} : (σ → α × σ) → m α)
|
||||
|
||||
export MonadState (get set modifyGet)
|
||||
export MonadState (set)
|
||||
|
||||
abbrev getThe (σ : Type u) {m : Type u → Type v} [MonadState σ m] : m σ :=
|
||||
MonadState.get
|
||||
|
||||
@[inline] abbrev modifyThe (σ : Type u) {m : Type u → Type v} [MonadState σ m] (f : σ → σ) : m PUnit :=
|
||||
MonadState.modifyGet fun s => (PUnit.unit, f s)
|
||||
|
||||
/-- Class for `get`, but `σ` is an outParam for convenience -/
|
||||
class HasGet (σ : outParam (Type u)) (m : Type u → Type v) :=
|
||||
(get : m σ)
|
||||
|
||||
export HasGet (get)
|
||||
|
||||
instance monadState.hasGet (σ : Type u) (m : Type u → Type v) [MonadState σ m] : HasGet σ m :=
|
||||
{ get := getThe σ }
|
||||
|
||||
class HasModifyGet (σ : outParam (Type u)) (m : Type u → Type v) :=
|
||||
(modifyGet {α : Type u} : (σ → α × σ) → m α)
|
||||
|
||||
export HasModifyGet (modifyGet)
|
||||
|
||||
def MonadState.toModifyGet {σ : Type u} {m : Type u → Type v} (s : MonadState σ m) : HasModifyGet σ m :=
|
||||
{ modifyGet := s.modifyGet }
|
||||
|
||||
instance monadState.hasModifyGet (σ : Type u) (m : Type u → Type v) [MonadState σ m] : HasModifyGet σ m :=
|
||||
{ modifyGet := fun α f => MonadState.modifyGet f }
|
||||
|
||||
section
|
||||
variables {σ : Type u} {m : Type u → Type v}
|
||||
|
||||
@[inline] def modify [MonadState σ m] (f : σ → σ) : m PUnit :=
|
||||
modifyGet (fun s => (PUnit.unit, f s))
|
||||
@[inline] def modify [HasModifyGet σ m] (f : σ → σ) : m PUnit :=
|
||||
modifyGet fun s => (PUnit.unit, f s)
|
||||
|
||||
@[inline] def getModify [MonadState σ m] [Monad m] (f : σ → σ) : m σ := do
|
||||
s ← get; modify f; pure s
|
||||
@[inline] def getModify [HasModifyGet σ m] [Monad m] (f : σ → σ) : m σ := do
|
||||
modifyGet fun s => (s, f s)
|
||||
|
||||
-- NOTE: The Ordering of the following two instances determines that the top-most `StateT` Monad layer
|
||||
-- will be picked first
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ stx ← getCur;
|
|||
idx ← getIdx;
|
||||
st ← get;
|
||||
-- reset prec/prec and store `mkParen` for the recursive call
|
||||
set { stxTrav := st.stxTrav };
|
||||
set { stxTrav := st.stxTrav : State };
|
||||
trace! `PrettyPrinter.parenthesize ("parenthesizing (contPrec := " ++ toString st.contPrec ++ ")" ++ MessageData.nest 2 (line ++ stx));
|
||||
adaptReader (fun (ctx : Context) => { ctx with mkParen := some mkParen }) x;
|
||||
{ minPrec := some minPrec, trailPrec := trailPrec, .. } ← get
|
||||
|
|
|
|||
|
|
@ -377,11 +377,11 @@ namespace MonadTraverser
|
|||
variables {m : Type → Type} [Monad m] [t : MonadTraverser m]
|
||||
|
||||
def getCur : m Syntax := Traverser.cur <$> t.st.get
|
||||
def setCur (stx : Syntax) : m Unit := @modify _ _ t.st (fun t => t.setCur stx)
|
||||
def goDown (idx : Nat) : m Unit := @modify _ _ t.st (fun t => t.down idx)
|
||||
def goUp : m Unit := @modify _ _ t.st (fun t => t.up)
|
||||
def goLeft : m Unit := @modify _ _ t.st (fun t => t.left)
|
||||
def goRight : m Unit := @modify _ _ t.st (fun t => t.right)
|
||||
def setCur (stx : Syntax) : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.setCur stx)
|
||||
def goDown (idx : Nat) : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.down idx)
|
||||
def goUp : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.up)
|
||||
def goLeft : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.left)
|
||||
def goRight : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.right)
|
||||
|
||||
def getIdx : m Nat := do
|
||||
st ← t.st.get;
|
||||
|
|
|
|||
|
|
@ -17,3 +17,14 @@ modify fun s => s + v;
|
|||
get
|
||||
|
||||
#eval (f2.run 10).run' 20
|
||||
|
||||
def f3 : StateT String (StateRefT Nat IO) Nat := do
|
||||
s ← get;
|
||||
n ← getThe Nat;
|
||||
set (s ++ ", " ++ toString n);
|
||||
s ← get;
|
||||
IO.println s;
|
||||
set (n+1);
|
||||
getThe Nat
|
||||
|
||||
#eval (f3.run' "test").run' 10
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue