diff --git a/src/Init/Control/State.lean b/src/Init/Control/State.lean index b80f0b2cf3..959c285d83 100644 --- a/src/Init/Control/State.lean +++ b/src/Init/Control/State.lean @@ -90,7 +90,7 @@ end StateT /-- An implementation of [MonadState](https://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-State-Class.html). In contrast to the Haskell implementation, we use overlapping instances to derive instances automatically from `monadLift`. -/ -class MonadState (σ : outParam (Type u)) (m : Type u → Type v) := +class MonadState (σ : Type u) (m : Type u → Type v) := /- Obtain the top-most State of a Monad stack. -/ (get : m σ) /- Set the top-most State of a Monad stack. -/ @@ -101,16 +101,42 @@ class MonadState (σ : outParam (Type u)) (m : Type u → Type v) := because the latter does not use the State linearly (without sufficient inlining). -/ (modifyGet {α : Type u} : (σ → α × σ) → m α) -export MonadState (get set modifyGet) +export MonadState (set) + +abbrev getThe (σ : Type u) {m : Type u → Type v} [MonadState σ m] : m σ := +MonadState.get + +@[inline] abbrev modifyThe (σ : Type u) {m : Type u → Type v} [MonadState σ m] (f : σ → σ) : m PUnit := +MonadState.modifyGet fun s => (PUnit.unit, f s) + +/-- Class for `get`, but `σ` is an outParam for convenience -/ +class HasGet (σ : outParam (Type u)) (m : Type u → Type v) := +(get : m σ) + +export HasGet (get) + +instance monadState.hasGet (σ : Type u) (m : Type u → Type v) [MonadState σ m] : HasGet σ m := +{ get := getThe σ } + +class HasModifyGet (σ : outParam (Type u)) (m : Type u → Type v) := +(modifyGet {α : Type u} : (σ → α × σ) → m α) + +export HasModifyGet (modifyGet) + +def MonadState.toModifyGet {σ : Type u} {m : Type u → Type v} (s : MonadState σ m) : HasModifyGet σ m := +{ modifyGet := s.modifyGet } + +instance monadState.hasModifyGet (σ : Type u) (m : Type u → Type v) [MonadState σ m] : HasModifyGet σ m := +{ modifyGet := fun α f => MonadState.modifyGet f } section variables {σ : Type u} {m : Type u → Type v} -@[inline] def modify [MonadState σ m] (f : σ → σ) : m PUnit := -modifyGet (fun s => (PUnit.unit, f s)) +@[inline] def modify [HasModifyGet σ m] (f : σ → σ) : m PUnit := +modifyGet fun s => (PUnit.unit, f s) -@[inline] def getModify [MonadState σ m] [Monad m] (f : σ → σ) : m σ := do -s ← get; modify f; pure s +@[inline] def getModify [HasModifyGet σ m] [Monad m] (f : σ → σ) : m σ := do +modifyGet fun s => (s, f s) -- NOTE: The Ordering of the following two instances determines that the top-most `StateT` Monad layer -- will be picked first diff --git a/src/Lean/PrettyPrinter/Parenthesizer.lean b/src/Lean/PrettyPrinter/Parenthesizer.lean index d3e208a7c6..02feef5345 100644 --- a/src/Lean/PrettyPrinter/Parenthesizer.lean +++ b/src/Lean/PrettyPrinter/Parenthesizer.lean @@ -191,7 +191,7 @@ stx ← getCur; idx ← getIdx; st ← get; -- reset prec/prec and store `mkParen` for the recursive call -set { stxTrav := st.stxTrav }; +set { stxTrav := st.stxTrav : State }; trace! `PrettyPrinter.parenthesize ("parenthesizing (contPrec := " ++ toString st.contPrec ++ ")" ++ MessageData.nest 2 (line ++ stx)); adaptReader (fun (ctx : Context) => { ctx with mkParen := some mkParen }) x; { minPrec := some minPrec, trailPrec := trailPrec, .. } ← get diff --git a/src/Lean/Syntax.lean b/src/Lean/Syntax.lean index ad9f5804b1..cda29b10e5 100644 --- a/src/Lean/Syntax.lean +++ b/src/Lean/Syntax.lean @@ -377,11 +377,11 @@ namespace MonadTraverser variables {m : Type → Type} [Monad m] [t : MonadTraverser m] def getCur : m Syntax := Traverser.cur <$> t.st.get -def setCur (stx : Syntax) : m Unit := @modify _ _ t.st (fun t => t.setCur stx) -def goDown (idx : Nat) : m Unit := @modify _ _ t.st (fun t => t.down idx) -def goUp : m Unit := @modify _ _ t.st (fun t => t.up) -def goLeft : m Unit := @modify _ _ t.st (fun t => t.left) -def goRight : m Unit := @modify _ _ t.st (fun t => t.right) +def setCur (stx : Syntax) : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.setCur stx) +def goDown (idx : Nat) : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.down idx) +def goUp : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.up) +def goLeft : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.left) +def goRight : m Unit := @modify _ _ t.st.toModifyGet (fun t => t.right) def getIdx : m Nat := do st ← t.st.get; diff --git a/tests/lean/run/stateRef.lean b/tests/lean/run/stateRef.lean index 9e8be0b172..1159aacc4a 100644 --- a/tests/lean/run/stateRef.lean +++ b/tests/lean/run/stateRef.lean @@ -17,3 +17,14 @@ modify fun s => s + v; get #eval (f2.run 10).run' 20 + +def f3 : StateT String (StateRefT Nat IO) Nat := do +s ← get; +n ← getThe Nat; +set (s ++ ", " ++ toString n); +s ← get; +IO.println s; +set (n+1); +getThe Nat + +#eval (f3.run' "test").run' 10