diff --git a/doc/changes.md b/doc/changes.md index 28a37c60d3..582d8574bd 100644 --- a/doc/changes.md +++ b/doc/changes.md @@ -242,7 +242,7 @@ master branch (aka work in progress branch) * `monad.has_monad_lift(_t)` ~> `has_monad_lift(_t)` * `monad.map_comp` ~> `comp_map` -* `state(_t).{read,write}` ~> `{get,put}` (general operations defined on any `monad_state_lift`) +* `state(_t).{read,write}` ~> `{get,put}` (general operations defined on any `monad_state`) * deleted `monad.monad_transformer` * deleted `monad.lift{n}`. Use `f <$> a1 <*> ... <*> an` instead of `monad.lift{n} f a1 ... an`. * merged `has_map` into `functor` diff --git a/library/init/category/state.lean b/library/init/category/state.lean index 613bc8bd65..7a18582c41 100644 --- a/library/init/category/state.lean +++ b/library/init/category/state.lean @@ -78,43 +78,39 @@ section end end state_t -/-- A specialization of `monad_lift` to lifting `state_t` that allows `σ` to be inferred. - - This class is roughly equivalent to `MonadState` from https://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-State-Class.html, - with the important distinction that it is automatically derived via the generic - `has_monad_lift` class. -/ -class monad_state_lift (σ : out_param (Type u)) (m : out_param (Type u → Type v)) (n : Type u → Type w) := -[has_lift : has_monad_lift_t (state_t σ m) n] - -attribute [instance] monad_state_lift.mk -local attribute [instance] monad_state_lift.has_lift +/-- An implementation of [MonadState](https://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-State-Class.html). + In contrast to the Haskell implementation, we use overlapping instances to derive instances + automatically from `monad_lift`. -/ +class monad_state (σ : out_param (Type u)) (m : Type u → Type v) := +(lift {} {α : Type u} : state σ α → m α) section -variables {σ : Type u} {m : Type u → Type v} {n : Type u → Type w} [monad m] [monad_state_lift σ m n] +variables {σ : Type u} {m : Type u → Type v} + +instance monad_state_trans {n : Type u → Type w} [has_monad_lift m n] [monad_state σ m] : monad_state σ n := +⟨λ α x, monad_lift (monad_state.lift x : m α)⟩ + +instance [monad m] : monad_state σ (state_t σ m) := +⟨λ α x, ⟨λ s, pure (x.run s)⟩⟩ + +variables [monad m] [monad_state σ m] /-- Obtain the top-most state of a monad stack. -/ -@[inline] def get : n σ := -@monad_lift _ _ _ _ (state_t.get : state_t σ m _) +@[inline] def get : m σ := +monad_state.lift state_t.get /-- Set the top-most state of a monad stack. -/ -@[inline] def put (st : σ) : n punit := -monad_lift (state_t.put st : state_t σ m _) +@[inline] def put (st : σ) : m punit := +monad_state.lift (state_t.put st) /-- Map the top-most state of a monad stack. Note: `modify f` may be preferable to `f <$> get >>= put` because the latter does not use the state linearly (without sufficient inlining). -/ -@[inline] def modify (f : σ → σ) : n punit := -monad_lift (state_t.modify f : state_t σ m _) +@[inline] def modify (f : σ → σ) : m punit := +monad_state.lift (state_t.modify f) end -/-- Get the state at a specific position in the monad stack. - - Example: -/ -@[inline] def get_type (m : Type u → Type v) {n : Type u → Type w} (σ : Type u) [has_monad_lift_t (state_t σ m) n] [monad m] : n σ := -get - - /-- A specialization of `monad_map` to `state_t` that allows `σ` to be inferred. -/ class monad_state_functor (σ σ' : out_param (Type u)) (m : out_param (Type u → Type v)) (n n' : Type u → Type w) := [functor {} : monad_functor_t (state_t σ m) (state_t σ' m) n n'] diff --git a/library/init/meta/smt/smt_tactic.lean b/library/init/meta/smt/smt_tactic.lean index 07f731e3ad..ad179c959e 100644 --- a/library/init/meta/smt/smt_tactic.lean +++ b/library/init/meta/smt/smt_tactic.lean @@ -63,7 +63,7 @@ section local attribute [reducible] smt_tactic meta instance : monad smt_tactic := by apply_instance meta instance : alternative smt_tactic := by apply_instance -meta instance : monad_state_lift smt_state tactic smt_tactic := by apply_instance +meta instance : monad_state smt_state smt_tactic := by apply_instance end /- We don't use the default state_t lift operation because only diff --git a/tests/lean/interactive/my_tac_class.lean b/tests/lean/interactive/my_tac_class.lean index 28cd1be81f..dd0dcf145a 100644 --- a/tests/lean/interactive/my_tac_class.lean +++ b/tests/lean/interactive/my_tac_class.lean @@ -4,7 +4,7 @@ state_t nat tactic section local attribute [reducible] mytac meta instance : monad mytac := by apply_instance -meta instance : monad_state_lift nat tactic mytac := by apply_instance +meta instance : monad_state nat mytac := by apply_instance meta instance : has_monad_lift tactic mytac := by apply_instance end diff --git a/tests/lean/interactive/rb_map_ts.lean b/tests/lean/interactive/rb_map_ts.lean index cbfcde1561..c0456596bd 100644 --- a/tests/lean/interactive/rb_map_ts.lean +++ b/tests/lean/interactive/rb_map_ts.lean @@ -4,7 +4,7 @@ state_t (name_map nat) tactic section local attribute [reducible] mytac meta instance : monad mytac := by apply_instance -meta instance : monad_state_lift (name_map nat) tactic mytac := by apply_instance +meta instance : monad_state (name_map nat) mytac := by apply_instance meta instance : has_monad_lift tactic mytac := by apply_instance end diff --git a/tests/lean/run/my_tac_class.lean b/tests/lean/run/my_tac_class.lean index df2565d802..272fe7f58c 100644 --- a/tests/lean/run/my_tac_class.lean +++ b/tests/lean/run/my_tac_class.lean @@ -4,7 +4,7 @@ state_t nat tactic section local attribute [reducible] mytac meta instance : monad mytac := by apply_instance -meta instance : monad_state_lift nat tactic mytac := by apply_instance +meta instance : monad_state nat mytac := by apply_instance meta instance : has_monad_lift tactic mytac := by apply_instance end diff --git a/tests/lean/run/state.lean b/tests/lean/run/state.lean index 7111b3c657..212e2cff09 100644 --- a/tests/lean/run/state.lean +++ b/tests/lean/run/state.lean @@ -19,7 +19,7 @@ do 0 ← read, -- unlifted #eval (lifted_test.run 0).run 1 -def infer_test {m n} [monad_state_lift ℕ m n] [monad m] [monad n] : n ℕ := +def infer_test {m} [monad_state ℕ m] [monad m] : m ℕ := do n ← get, -- can infer σ through class inference pure n.succ @@ -43,10 +43,8 @@ do -- zoom in on second elem def bistate_test : state_t ℕ (state_t bool io) unit := do 0 ← get, -- outer state_t wins - -- manual + -- can always lift manually tt ← monad_lift (get : state_t bool io bool), - -- needs to mention inner monad - tt ← get_type io bool, pure () #eval (bistate_test.run 0).run tt