lean4-htt/src/Init/Data/Iterators/PostconditionMonad.lean
Paul Reichert 7a47bfa208
feat: flatMap iterator combinator (#10728)
This PR introduces the `flatMap` iterator combinator. It also adds
lemmas relating `flatMap` to `toList` and `toArray`.
2025-10-14 12:50:54 +00:00

267 lines
11 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Reichert
-/
module
prelude
public import Init.Control.Lawful.Basic
public import Init.Data.Subtype.Basic
public import Init.PropLemmas
public import Init.Control.Lawful.MonadLift.Basic
public section
namespace Std.Iterators
/--
`PostconditionT m α` packages an operation in the monad `m` with an intrinsic proof that some
postcondition holds for its `α`-valued monadic result. Concretely, it is a predicate `P` on `α`
together with an element of `m ({ a // P a })`. This is a helpful tool for intrinsic verification,
notably termination proofs, in the context of iterators.

`PostconditionT m` is a monad whenever `m` is one. Two caveats apply:

* `PostconditionT m α` is a structure, so the compiler will generate inefficient code from
  recursive functions returning `PostconditionT m α`. Optimizations for `ReaderT`, `StateT` etc.
  aren't applicable for structures.
* `PostconditionT m α` is not a well-behaved monad transformer, because `PostconditionT.lift`
  commutes neither with `pure` nor with `bind`.
-/
@[unbox]
structure PostconditionT (m : Type w → Type w') (α : Type w) where
  /--
  The postcondition: a predicate satisfied by the return value(s) of the `m`-monadic operation.
  -/
  Property : α → Prop
  /--
  The underlying monadic operation. Each value it returns comes bundled with a proof that the
  value satisfies `Property`.
  -/
  operation : m (Subtype Property)
/--
Embeds an operation of `m` into `PostconditionT m`, attaching only the trivial postcondition
`fun _ => True`.

Caution: `lift` is not a lawful lift function. For instance, `pure a : PostconditionT m α` is
not the same as `PostconditionT.lift (pure a : m α)`.
-/
@[always_inline, inline, expose]
def PostconditionT.lift {α : Type w} {m : Type w → Type w'} [Functor m] (x : m α) :
    PostconditionT m α where
  Property := fun _ => True
  -- Pair every result with the trivial proof of `True`.
  operation := (fun a => ⟨a, .intro⟩) <$> x
/--
Wraps the value `a` in `PostconditionT m`. The attached postcondition `fun y => a = y` records
that the returned value is exactly `a`.
-/
@[always_inline, inline, expose]
protected def PostconditionT.pure {m : Type w → Type w'} [Pure m] {α : Type w}
    (a : α) : PostconditionT m α where
  Property := fun y => a = y
  operation := pure ⟨a, rfl⟩
/--
Turns a monadic value of type `m { a : α // P a }` into a `PostconditionT m α` whose
postcondition is `P` and whose operation is the given value.
-/
@[always_inline, inline]
def PostconditionT.liftWithProperty {α : Type w} {m : Type w → Type w'} {P : α → Prop}
    (x : m { a // P a }) : PostconditionT m α where
  Property := P
  operation := x
/--
Given a function `f : α → β`, returns a function `PostconditionT m α → PostconditionT m β`,
turning `PostconditionT m` into a functor.

The postcondition of `x.map f` states that the return value is the image under `f` of some
`a : α` satisfying `x.Property`.
-/
@[always_inline, inline, expose]
protected def PostconditionT.map {m : Type w → Type w'} [Functor m] {α β : Type w}
    (f : α → β) (x : PostconditionT m α) : PostconditionT m β where
  Property := fun b => ∃ a : Subtype x.Property, f a.1 = b
  -- The witness of the existential is the original result `r` itself.
  operation := (fun r => ⟨f r.val, _, rfl⟩) <$> x.operation
/--
Given a function `f : α → PostconditionT m β`, returns a function
`PostconditionT m α → PostconditionT m β`, turning `PostconditionT m` into a monad.

The resulting postcondition holds for `b` if there is some `a` satisfying `x.Property` such that
`b` satisfies `(f a).Property`.
-/
@[always_inline, inline, expose]
protected def PostconditionT.bind {m : Type w → Type w'} [Monad m] {α β : Type w}
    (x : PostconditionT m α) (f : α → PostconditionT m β) : PostconditionT m β where
  Property := fun b => ∃ a, x.Property a ∧ (f a).Property b
  operation :=
    x.operation >>= fun a =>
      -- Re-pack the second result together with the intermediate value and both proofs.
      (fun b => ⟨b.val, a.val, a.property, b.property⟩) <$> (f a).operation
/--
A variant of `bind` whose continuation `f` receives the intermediate value bundled with the proof
that it satisfies the postcondition of `x`.
-/
@[always_inline, inline]
protected def PostconditionT.pbind {m : Type w → Type w'} [Monad m] {α β : Type w}
    (x : PostconditionT m α) (f : Subtype x.Property → PostconditionT m β) :
    PostconditionT m β where
  Property := fun b => ∃ a, (f a).Property b
  operation :=
    x.operation >>= fun a =>
      (fun b => ⟨b.val, a, b.property⟩) <$> (f a).operation
/--
Lifts an operation from `m` to `PostconditionT m` and then applies `PostconditionT.map`.

The postcondition states that the result is of the form `f a` for some `a : α`.
-/
@[always_inline, inline]
protected def PostconditionT.liftMap {m : Type w → Type w'} [Monad m] {α β : Type w}
    (f : α → β) (x : m α) : PostconditionT m β where
  -- NOTE(review): only `<$>` is used below, so `[Functor m]` might suffice — confirm with
  -- callers before relaxing the instance argument.
  Property := fun b => ∃ a, f a = b
  operation := (fun a => ⟨f a, a, rfl⟩) <$> x
/--
Converts an operation from `PostconditionT m` to `m`, discarding the postcondition: every result
is projected to its underlying value.
-/
@[always_inline, inline]
def PostconditionT.run {m : Type w → Type w'} [Monad m] {α : Type w} (x : PostconditionT m α) :
    m α :=
  (·.val) <$> x.operation
/-- `PostconditionT m` is a functor whenever `m` is; mapping is given by `PostconditionT.map`. -/
instance {m : Type w → Type w'} [Functor m] : Functor (PostconditionT m) :=
  { map := PostconditionT.map }
/--
`PostconditionT m` is a monad whenever `m` is; `pure` and `bind` are given by
`PostconditionT.pure` and `PostconditionT.bind`.
-/
instance {m : Type w → Type w'} [Monad m] : Monad (PostconditionT m) :=
  { pure := PostconditionT.pure
    bind := PostconditionT.bind }
/-- The `pure` of the `Monad (PostconditionT m)` instance is `PostconditionT.pure` by definition. -/
theorem PostconditionT.pure_eq_pure {m : Type w → Type w'} [Monad m] {α} {a : α} :
    pure a = PostconditionT.pure (m := m) a := rfl
/-- The postcondition of `pure x` is "the result equals `x`". -/
@[simp]
theorem PostconditionT.property_pure {m : Type w → Type w'} [Monad m] {α : Type w} {x : α} :
    (pure x : PostconditionT m α).Property = (x = ·) := by
  rfl
/--
The operation of `pure x` is `pure` in `m`, applied to `x` bundled with the proof that it
satisfies the postcondition of `pure x`.
-/
@[simp]
theorem PostconditionT.operation_pure {m : Type w → Type w'} [Monad m] {α : Type w} {x : α} :
    (pure x : PostconditionT m α).operation =
      pure ⟨x, property_pure (m := m) ▸ rfl⟩ := by
  rfl
/--
Extensionality for `PostconditionT`: `x = y` holds if the postconditions agree (`h`) and the
operation of `x`, with its proofs transported along `h`, equals the operation of `y` (`h'`).
-/
theorem PostconditionT.ext {m : Type w → Type w'} [Monad m] [LawfulMonad m]
    {α : Type w} {x y : PostconditionT m α}
    (h : x.Property = y.Property)
    (h' : (fun p => ⟨p.1, h ▸ p.2⟩) <$> x.operation = y.operation) :
    x = y := by
  cases x; cases y; cases h
  simpa using h'
/--
Characterization of equality in `PostconditionT m α`: `x = y` holds exactly when the
postconditions agree and the operations agree after transporting proofs along that equality.
-/
theorem PostconditionT.ext_iff {m : Type w → Type w'} [Monad m] [LawfulMonad m]
    {α : Type w} {x y : PostconditionT m α} :
    x = y ↔ ∃ h : x.Property = y.Property, (fun p => ⟨p.1, h ▸ p.2⟩) <$> x.operation = y.operation := by
  constructor
  · rintro rfl
    exact ⟨rfl, by simp⟩
  · rintro ⟨h, h'⟩
    exact PostconditionT.ext h h'
/-- Mapping `f` over `x` is the same as binding `x` with `pure ∘ f`. -/
@[simp]
protected theorem PostconditionT.map_eq_pure_bind {m : Type w → Type w'} [Monad m] [LawfulMonad m]
    {α β : Type w} {f : α → β} {x : PostconditionT m α} :
    x.map f = x.bind (pure ∘ f) := by
  apply PostconditionT.ext <;> simp [PostconditionT.map, PostconditionT.bind]
/-- Left identity law: binding `pure a` with `f` yields `f a`. -/
@[simp]
protected theorem PostconditionT.pure_bind {m : Type w → Type w'} [Monad m] [LawfulMonad m]
    {α β : Type w} {f : α → PostconditionT m β} {a : α} :
    (pure a : PostconditionT m α).bind f = f a := by
  apply PostconditionT.ext <;> simp [pure, PostconditionT.pure, PostconditionT.bind]
/-- Right identity law: binding `x` with `pure` yields `x`. -/
@[simp]
protected theorem PostconditionT.bind_pure {m : Type w → Type w'} [Monad m] [LawfulMonad m]
    {α : Type w} {x : PostconditionT m α} :
    x.bind pure = x := by
  apply PostconditionT.ext <;> simp [pure, PostconditionT.pure, PostconditionT.bind]
/-- Associativity law for `PostconditionT.bind`. -/
@[simp]
protected theorem PostconditionT.bind_assoc {m : Type w → Type w'} [Monad m] [LawfulMonad m]
    {α β γ : Type w} {x : PostconditionT m α} {f : α → PostconditionT m β}
    {g : β → PostconditionT m γ} :
    (x.bind f).bind g = x.bind (fun a => (f a).bind g) := by
  apply PostconditionT.ext
  · simp [PostconditionT.bind]
  · simp only [PostconditionT.bind, bind_assoc, bind_map_left, map_bind, Functor.map_map]
    -- The two postconditions differ only in how the existential witnesses are nested.
    ext c
    constructor
    · rintro ⟨b, ⟨a, ha, hb⟩, h⟩
      exact ⟨a, ha, b, hb, h⟩
    · rintro ⟨a, ha, b, hb, h⟩
      exact ⟨b, ⟨a, ha, hb⟩, h⟩
/-- Mapping `f` over `pure a` yields `pure (f a)`. -/
@[simp]
protected theorem PostconditionT.map_pure {m : Type w → Type w'} [Monad m] [LawfulMonad m]
    {α β : Type w} {f : α → β} {a : α} :
    (pure a : PostconditionT m α).map f = pure (f a) := by
  apply PostconditionT.ext <;> simp [pure, PostconditionT.map, PostconditionT.pure]
/-- `PostconditionT m` is a lawful monad whenever `m` is a lawful monad. -/
instance [Monad m] [LawfulMonad m] : LawfulMonad (PostconditionT m) where
  map_const {α β} := by ext a x; simp [Functor.mapConst, Functor.map]
  id_map {α} x := by simp [Functor.map]
  comp_map {α β γ} g h := by intro x; simp [Functor.map]; rfl
  seqLeft_eq {α β} x y := by simp [SeqLeft.seqLeft, Functor.map, Seq.seq]; rfl
  seqRight_eq {α β} x y := by simp [Seq.seq, SeqRight.seqRight, Functor.map]
  pure_seq g x := by simp [Seq.seq, Functor.map]
  bind_pure_comp f x := by simp [Functor.map, Bind.bind]; rfl
  bind_map f x := by simp [Seq.seq, Functor.map]; rfl
  pure_bind x f := PostconditionT.pure_bind
  bind_assoc x f g := PostconditionT.bind_assoc
/--
The postcondition of `x.map f` holds for `b` exactly when `b = f a` for some `a` satisfying
`x.Property`.
-/
@[simp]
theorem PostconditionT.property_map {m : Type w → Type w'} [Functor m] {α β : Type w}
    {x : PostconditionT m α} {f : α → β} {b : β} :
    (x.map f).Property b ↔ (∃ a, f a = b ∧ x.Property a) := by
  simp only [PostconditionT.map]
  apply Iff.intro
  · rintro ⟨⟨a, ha⟩, h⟩
    exact ⟨a, h, ha⟩
  · rintro ⟨a, h, ha⟩
    exact ⟨⟨a, ha⟩, h⟩
/--
The operation of `x.map f` maps over `x.operation`, pairing each image with the proof of the
mapped postcondition obtained from `property_map`.
-/
@[simp]
theorem PostconditionT.operation_map {m : Type w → Type w'} [Functor m] {α β : Type w}
    {x : PostconditionT m α} {f : α → β} :
    (x.map f).operation =
      (fun a => ⟨_, (property_map (m := m)).mpr ⟨a.1, rfl, a.2⟩⟩) <$> x.operation := by
  rfl
/--
Unfolds the operation of `x.bind f`: run `x.operation`, then run the operation of `f` on the
result and re-pack the outcome with a proof of the combined postcondition.
-/
@[simp]
theorem PostconditionT.operation_bind {m : Type w → Type w'} [Monad m] {α β : Type w}
    {x : PostconditionT m α} {f : α → PostconditionT m β} :
    (x.bind f).operation = (do
      let a ← x.operation
      (fun fa => ⟨fa.1, by exact ⟨a.1, a.2, fa.2⟩⟩) <$> (f a.1).operation) := by
  rfl
/--
The same unfolding as `PostconditionT.operation_bind`, but stated for the monadic `>>=` notation
instead of `PostconditionT.bind`. This variant is not tagged `@[simp]`.
-/
theorem PostconditionT.operation_bind' {m : Type w → Type w'} [Monad m] {α β : Type w}
    {x : PostconditionT m α} {f : α → PostconditionT m β} :
    (x >>= f).operation = (do
      let a ← x.operation
      (fun fa => ⟨fa.1, by exact ⟨a.1, a.2, fa.2⟩⟩) <$> (f a.1).operation) := by
  rfl
/-- The postcondition of `lift x` is the trivial predicate `fun _ => True`. -/
@[simp]
theorem PostconditionT.property_lift {m : Type w → Type w'} [Functor m] {α : Type w} {x : m α} :
    (lift x : PostconditionT m α).Property = (fun _ => True) := by
  rfl
/--
The operation of `lift x` maps over `x`, pairing each result with the trivial proof of the
postcondition of `lift x`.
-/
@[simp]
theorem PostconditionT.operation_lift {m : Type w → Type w'} [Functor m] {α : Type w} {x : m α} :
    (lift x : PostconditionT m α).operation =
      (⟨·, property_lift (m := m) ▸ True.intro⟩) <$> x := by
  rfl
/--
A `MonadLift m n` instance induces a `MonadLift (PostconditionT m) (PostconditionT n)` instance:
the postcondition is kept as is and the underlying operation is lifted from `m` to `n`.
-/
instance {m : Type w → Type w'} {n : Type w → Type w''} [MonadLift m n] :
    MonadLift (PostconditionT m) (PostconditionT n) where
  monadLift x := ⟨x.Property, monadLift x.operation⟩
/--
If lifting from `m` to `n` is lawful, then so is the induced lifting from `PostconditionT m` to
`PostconditionT n`.
-/
instance PostconditionT.instLawfulMonadLift {m : Type w → Type w'} {n : Type w → Type w''}
    [MonadLift m n] [Monad m] [Monad n] [LawfulMonad m] [LawfulMonad n]
    [LawfulMonadLift m n] :
    LawfulMonadLift (PostconditionT m) (PostconditionT n) where
  monadLift_pure a := by
    simp [MonadLift.monadLift, monadLift, LawfulMonadLift.monadLift_pure, pure,
      PostconditionT.pure]
  monadLift_bind x f := by
    simp only [MonadLift.monadLift, bind, monadLift, LawfulMonadLift.monadLift_bind,
      PostconditionT.bind, mk.injEq, heq_eq_eq, true_and]
    simp only [map_eq_pure_bind, LawfulMonadLift.monadLift_bind, LawfulMonadLift.monadLift_pure]
end Std.Iterators