lean4-htt/src/Init/Control/State.lean
2025-10-16 20:27:46 +00:00

200 lines
6.2 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Sebastian Ullrich
The State monad transformer.
-/
module
prelude
public import Init.Control.Except
public section
set_option linter.missingDocs true
universe u v w
/--
A monad transformer that equips the monad `m` with a mutable state of type `σ`.

An action of type `StateT σ m α` is a function that receives the initial state and produces, in
`m`, a pair consisting of the resulting value and the resulting state.
-/
@[expose]
def StateT (σ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) :=
  σ → m (α × σ)
/--
Runs a stateful action in the underlying monad `m`, starting from the initial state `s`. The
result is the action's value paired with the final state.
-/
@[always_inline, inline, expose]
def StateT.run {σ : Type u} {m : Type u → Type v} {α : Type u}
    (x : StateT σ m α) (s : σ) : m (α × σ) :=
  x s
/--
Runs a stateful action in the underlying monad `m`, starting from the initial state `s`. Only the
action's value is returned; the final state is dropped.
-/
@[always_inline, inline, expose]
def StateT.run' {σ : Type u} {m : Type u → Type v} [Functor m] {α : Type u}
    (x : StateT σ m α) (s : σ) : m α :=
  -- Keep the first component (the value) of the resulting value/state pair.
  (fun p => p.1) <$> x s
/--
The state monad built from tuples: `StateT` applied to the identity monad `Id`.

An action in `StateM σ` is a function from an initial state to a value paired with a final state.
-/
@[expose, reducible]
def StateM (σ α : Type u) : Type u :=
  StateT σ Id α
instance {σ α} [Subsingleton σ] [Subsingleton α] : Subsingleton (StateM σ α) where
  allEq x y := by
    -- Two state actions are equal once they agree on every initial state.
    apply funext
    intro s
    -- Both results are value/state pairs; identify them component by component.
    match x s, y s with
    | (v₁, t₁), (v₂, t₂) =>
      rw [Subsingleton.elim v₁ v₂, Subsingleton.elim t₁ t₂]
namespace StateT
section
-- `σ` is the state type and `m` the underlying monad for the declarations in this section.
variable {σ : Type u} {m : Type u → Type v}
-- `m` is required to be a monad; `α` and `β` are the result types used by the declarations below.
variable [Monad m] {α β : Type u}
/--
Produces the value `a` and leaves the state untouched. Normally reached through `Pure.pure`.
-/
@[always_inline, inline, expose]
protected def pure (a : α) : StateT σ m α :=
  fun s₀ => pure (a, s₀)
/--
Runs `x`, then feeds its value to `f` and runs the resulting action in the state that `x` produced.
Normally reached through the `>>=` operator.
-/
@[always_inline, inline, expose]
protected def bind (x : StateT σ m α) (f : α → StateT σ m β) : StateT σ m β :=
  fun s₀ => do
    let (a, s₁) ← x s₀
    f a s₁
/--
Applies `f` to the value returned by `x`, keeping the state that `x` produced. Normally reached
through the `<$>` operator.
-/
@[always_inline, inline, expose]
protected def map (f : α → β) (x : StateT σ m α) : StateT σ m β :=
  fun s₀ => do
    let (a, s₁) ← x s₀
    pure (f a, s₁)
-- The monad structure of `StateT σ m` is assembled from the protected primitives of the same name.
@[always_inline]
instance : Monad (StateT σ m) where
  map  := StateT.map
  pure := StateT.pure
  bind := StateT.bind
/--
Error recovery: runs `x₁` and, if it fails in `m`, runs `x₂ ()` instead. The fallback starts from
the same initial state as `x₁`, so state changes made by the failed attempt are rolled back.
Normally reached through the `<|>` operator.
-/
@[always_inline, inline]
protected def orElse [Alternative m] {α : Type u}
    (x₁ : StateT σ m α) (x₂ : Unit → StateT σ m α) : StateT σ m α :=
  fun s₀ => x₁ s₀ <|> x₂ () s₀
/--
Fails with a recoverable error by failing in the underlying monad `m`, ignoring the incoming
state. When the error is recovered from, the state is rolled back.
-/
@[always_inline, inline]
protected def failure [Alternative m] {α : Type u} : StateT σ m α :=
  fun (_ : σ) => failure
-- `Alternative` is inherited from the underlying monad via `StateT.orElse` and `StateT.failure`.
instance [Alternative m] : Alternative (StateT σ m) where
  orElse  := StateT.orElse
  failure := StateT.failure
/--
Returns the current value of the mutable state, leaving the state itself unchanged.

Because the state is both returned and kept, its reference count is incremented, which may
inhibit in-place updates.
-/
@[always_inline, inline, expose]
protected def get : StateT σ m σ :=
  fun cur => pure (cur, cur)
/--
Overwrites the mutable state with the given value. The previous state is discarded.
-/
@[always_inline, inline, expose]
protected def set : σ → StateT σ m PUnit :=
  fun newState _ => pure (⟨⟩, newState)
/--
Applies a function to the current state that both computes a new state and a value. The new state
replaces the current state, and the value is returned.

It is equivalent to `do let (a, s) := f (← StateT.get); StateT.set s; pure a`. However, using
`StateT.modifyGet` may lead to better performance because it doesn't add a new reference to the
state value, and additional references can inhibit in-place updates of data.
-/
@[always_inline, inline, expose]
protected def modifyGet (f : σ → α × σ) : StateT σ m α :=
  -- `f s` is already the value/state pair that the action must return.
  fun s => pure (f s)
/--
Embeds an action of the underlying monad into the monad with state. The action is run and its
result returned; the state is passed through unmodified.

This function is typically implicitly accessed via a `MonadLiftT` instance as part of [automatic
lifting](lean-manual://section/monad-lifting).
-/
@[always_inline, inline, expose]
protected def lift {α : Type u} (t : m α) : StateT σ m α :=
  fun s₀ => do
    let a ← t
    pure (a, s₀)
-- Actions of the underlying monad lift into `StateT σ m` through `StateT.lift`.
instance : MonadLift m (StateT σ m) where
  monadLift := StateT.lift
-- A transformation `f` of the underlying monad is applied to the result of running the stateful
-- action on the incoming state.
@[always_inline]
instance (σ m) : MonadFunctor m (StateT σ m) where
  monadMap := fun f x s => f (x s)
-- Exceptions of type `ε` are delegated to the underlying monad.
@[always_inline]
instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (StateT σ m) where
  -- Throwing happens in `m` and is then lifted.
  throw := StateT.lift ∘ throwThe ε
  -- The handler is started from the state the protected action was started from.
  tryCatch := fun body handler s₀ =>
    tryCatchThe ε (body s₀) (fun err => handler err s₀)
end
end StateT
/--
Builds an implementation of `ForIn.forIn` on top of a `ForM` instance.

The loop accumulator is threaded through a `StateT β` layer, and early termination (`.done`) is
encoded as an error in an `ExceptT β` layer carrying the final accumulator.
-/
@[always_inline, inline]
def ForM.forIn [Monad m] [ForM (StateT β (ExceptT β m)) ρ α]
    (x : ρ) (b : β) (f : α → β → m (ForInStep β)) : m β := do
  -- One iteration: run `f` and translate its verdict into the `StateT`/`ExceptT` stack.
  let step elem acc := .mk do
    match ← f elem acc with
    | .yield next => pure (.ok (⟨⟩, next))
    | .done final => pure (.error final)
  match ← forM (m := StateT β (ExceptT β m)) (α := α) x step |>.run b |>.run with
  -- The traversal ran to completion: the accumulator is the final state.
  | .ok finished => pure finished.2
  -- The traversal stopped early: the accumulator was carried by the error.
  | .error stopped => pure stopped
section
variable {σ : Type u} {m : Type u → Type v}

-- Each `MonadStateOf` operation is implemented by the `StateT` primitive of the same name.
instance [Monad m] : MonadStateOf σ (StateT σ m) where
  modifyGet := StateT.modifyGet
  get       := StateT.get
  set       := StateT.set
end
@[always_inline]
instance StateT.monadControl (σ : Type u) (m : Type u → Type v) [Monad m] :
    MonadControl m (StateT σ m) where
  -- The captured result of an action is its value together with the final state.
  stM := fun α => α × σ
  -- Read the current state and hand `f` a runner that starts every action from it.
  liftWith := fun f => do
    let cur ← get
    liftM (f (fun act => act.run cur))
  -- Run the captured computation in `m`, reinstall its state and return its value.
  restoreM := fun captured => do
    let (a, s') ← liftM captured
    set s'
    pure a
@[always_inline]
instance StateT.tryFinally {m : Type u → Type v} {σ : Type u} [MonadFinally m] [Monad m] :
    MonadFinally (StateT σ m) where
  tryFinally' := fun x h s₀ => do
    let ((a, _), (b, sFinal)) ← tryFinally' (x s₀) fun
      -- `x` produced a result: the finalizer sees its value and continues from its state.
      | some (val, sAfter) => h (some val) sAfter
      -- `x` produced no result: the finalizer runs from the initial state.
      | none => h none s₀
    -- The overall state is the one left behind by the finalizer.
    pure ((a, b), sFinal)