diff --git a/src/Init/Control/Basic.lean b/src/Init/Control/Basic.lean index 52a3e97b03..58ae5398ba 100644 --- a/src/Init/Control/Basic.lean +++ b/src/Init/Control/Basic.lean @@ -67,7 +67,6 @@ infixr:35 " <&&> " => andM not <$> x /-! - # How `MonadControl` works There is a [tutorial by Alexis King](https://lexi-lambda.github.io/blog/2019/09/07/demystifying-monadbasecontrol/) that this docstring is based on. @@ -76,11 +75,14 @@ Suppose we have `foo : ∀ α, IO α → IO α` and `bar : StateT σ IO β` (ie, We might want to 'map' `bar` by `foo`. Concretely we would write this as: ```lean -def mapped_foo : StateT σ m β → StateT σ m β := do +constant foo : ∀ {α}, IO α → IO α +constant bar : StateT σ IO β + +def mapped_foo : StateT σ IO β := do let s ← get - let (s', v) ← lift <| foo <| StateT.run bar s + let (b, s') ← liftM <| foo <| StateT.run bar s set s' - pure v + return b ``` This is fine but it's not going to generalise, what if we replace `StateT Nat IO` with a large tower of monad transformers? @@ -96,12 +98,21 @@ has the type `IO (σ × β)`. The key idea is that `σ × β` contains all of th Now lets define some values to generalise `mapped_foo`: - Write `IO (σ × β)` as `IO (stM β)` - Write `StateT.run . s` as `mapInBase : StateT σ IO α → IO (stM β)` -- Define `restoreM : IO (stM α) → StateT σ IO α` as `fun x => do let (s', v) ← lift x; set s'; pure v` - -To get +- Define `restoreM : IO (stM α) → StateT σ IO α` as below ```lean -def mapped_foo : StateT σ m β → StateT σ m β := do +def stM (α : Type) := α × σ + +def restoreM (x : IO (stM α)) : StateT σ IO α := do + let (a,s) ← liftM x + set s + return a +``` + +To get: + +```lean +def mapped_foo' : StateT σ IO β := do let s ← get let mapInBase := fun z => StateT.run z s restoreM <| foo <| mapInBase bar @@ -110,19 +121,20 @@ def mapped_foo : StateT σ m β → StateT σ m β := do and finally define ```lean -control : {α : Type u} → (({β : Type u} → StateT σ IO β → IO (stM β)) → IO (stM α)) → StateT σ IO α - | α, f => do - let s ← get - let mapInBase := fun {β} (z : StateT σ IO β) => StateT.run z s - let r : IO (stM α) := f mapInBase - restoreM r +def control {α : Type} + (f : ({β : Type} → StateT σ IO β → IO (stM β)) → IO (stM α)) + : StateT σ IO α := do + let s ← get + let mapInBase := fun {β} (z : StateT σ IO β) => StateT.run z s + let r : IO (stM α) := f mapInBase + restoreM r ``` -now we can write `mapped_foo` as: +Now we can write `mapped_foo` as: ```lean -def mapped_foo : StateT σ m β → StateT σ m β := - control fun mapInBase => foo (mapInBase bar) +def mapped_foo'' : StateT σ IO β := + control (fun mapInBase => foo (mapInBase bar)) ``` The core idea of `mapInBase` is that given any `β`, it runs an instance of @@ -131,7 +143,7 @@ Once it's been through `foo` we can then unpack the state again with `restoreM`. Hence we can apply `foo` to `bar` without losing track of the state. Here `stM β = σ × β` is the 'packaged result state', but we can generalise: -if we have a tower `StateT σ₁ <| StateT σ₂ <| IO`, then we can get the +if we have a tower `StateT σ₁ <| StateT σ₂ <| IO`, then the composite packaged state is going to be `stM₁₂ β := σ₁ × σ₂ × β` or `stM₁₂ := stM₁ ∘ stM₂`. Now we can define `MonadControl m n`. Call `m` the 'base monad', in the above example it was `IO`. @@ -146,12 +158,33 @@ in a new nested metavariable context. We can lift this to `withNewMctxDepth : n Which means that we can also run `withNewMctxDepth` in the `Tactic` monad without needing to faff around with lifts and all the other boilerplate needed in `mapped_foo`. +## Relationship to `MonadFunctor` + +A stricter form of `MonadControl` is `MonadFunctor`, which defines +`monadMap {α} : (∀ {β}, m β → m β) → n α → n α`. Using `monadMap` it is also possible to define `mapped_foo` above. +However there are some mappings which can't be derived using `MonadFunctor`. For example: + +```lean,ignore + @[inline] def map1MetaM [MonadControlT MetaM m] [Monad m] (f : forall {α}, (β → MetaM α) → MetaM α) {α} (k : β → m α) : m α := + control fun runInBase => f fun b => runInBase <| k b + + @[inline] def map2MetaM [MonadControlT MetaM m] [Monad m] (f : forall {α}, (β → γ → MetaM α) → MetaM α) {α} (k : β → γ → m α) : m α := + control fun runInBase => f fun b c => runInBase <| k b c +``` + +In these examples, `MonadControl` is needed because the lifted function +needs to be all-quantified over the monadic return value, +as that is where the surrounding monad's (`n`) state is stored. +In the Lean monad-transformer stacks, there are no `MonadFunctor`s that +are not also `MonadControl`s and so `MonadFunctor` is not used. + -/ + /-- MonadControl is a way of stating that the monad `m` can be 'run inside' the monad `n`. This is the same as [`MonadBaseControl`](https://hackage.haskell.org/package/monad-control-1.0.3.1/docs/Control-Monad-Trans-Control.html#t:MonadBaseControl) in Haskell. -See the comment above this docstring for an explanation of how to use MonadControl. +To learn about `MonadControl`, see the comment above this docstring. -/ class MonadControl (m : Type u → Type v) (n : Type u → Type w) where diff --git a/tests/lean/run/MonadControl_tutorial.lean b/tests/lean/run/MonadControl_tutorial.lean new file mode 100644 index 0000000000..f7f29d5bfe --- /dev/null +++ b/tests/lean/run/MonadControl_tutorial.lean @@ -0,0 +1,42 @@ +namespace Tutorial + +/- This file contains the code examples that are used in the +monad control docstring in "How `MonadControl` works" in src/Init/Control/Basic.lean -/ + +def σ := Nat +@[reducible] +def β := String + +constant foo : ∀ {α}, IO α → IO α +constant bar : StateT σ IO β + +def mapped_foo : StateT σ IO β := do + let s ← get + let (b, s') ← liftM <| foo <| StateT.run bar s + set s' + return b + +def stM (α : Type) := α × σ + +def restoreM (x : IO (stM α)) : StateT σ IO α := do + let (a,s) ← liftM x + set s + return a + +def mapped_foo' : StateT σ IO β := do + let s ← get + let mapInBase := fun z => StateT.run z s + restoreM <| foo <| mapInBase bar + +def control {α : Type} + (f : ({β : Type} → StateT σ IO β → IO (stM β)) → IO (stM α)) + : StateT σ IO α := do + let s ← get + let mapInBase := fun {β} (z : StateT σ IO β) => StateT.run z s + let r : IO (stM α) := f mapInBase + restoreM r + +def mapped_foo'' : StateT σ IO β := + control (fun mapInBase => foo (mapInBase bar)) + +end Tutorial