This PR introduces an `Iter.step_eq` lemma that unfolds an `Iter.step` call completely in a single rewrite, so proofs no longer need to unfold each intermediate layer of definitions by hand.
214 lines
7.8 KiB
Text
214 lines
7.8 KiB
Text
/-
|
||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Paul Reichert
|
||
-/
|
||
module
|
||
|
||
prelude
|
||
public import Init.Data.Iterators.Consumers.Monadic.Loop
|
||
public import Init.Classical
|
||
import Init.ByCases
|
||
import Init.Omega
|
||
|
||
@[expose] public section
|
||
|
||
/-!
This module provides the iterator combinator `IterM.take`, which limits an iterator to at most `n`
outputs, together with the related combinator `IterM.toTake`.
-/

namespace Std

-- Implicit binders shared by the declarations below: `α` is the state type of the underlying
-- iterator, `m` is the type constructor its steps run in (declarations that step iterators
-- additionally require `[Monad m]`), and `β` is the type of the values it emits.
variable {α : Type w} {m : Type w → Type w'} {β : Type w}

/--
The internal state of the `IterM.take` iterator combinator. Iterators produced by `IterM.toTake`
carry this state as well.
-/
@[unbox]
structure Iterators.Types.Take (α : Type w) (m : Type w → Type w') {β : Type w} [Iterator α m β] where
  /--
  Internal implementation detail of the iterator library.

  Caution: For `take n`, `countdown` is `n + 1`.
  A `countdown` of `1` means that the limit has been reached: the next step is `.done`, and `inner`
  is not stepped anymore. Each `yield` decreases `countdown` by one (truncating at zero), while a
  `skip` leaves it unchanged.

  If `countdown` is zero, the combinator only terminates when `inner` terminates.
  -/
  countdown : Nat
  /-- Internal implementation detail of the iterator library: the underlying iterator. -/
  inner : IterM (α := α) m β
  /--
  Internal implementation detail of the iterator library.

  This proof term ensures that a `take` always produces a finite iterator from a productive one:
  either `countdown` is positive, which bounds the number of remaining outputs, or `inner` is
  itself finite.
  -/
  finite : countdown > 0 ∨ Finite α m

open Std.Iterators Std.Iterators.Types

/--
Given an iterator `it` and a natural number `n`, `it.take n` emits the first `n` values of `it`,
in their original order, and then terminates. If `it` terminates before producing `n` values,
`it.take n` terminates at the same point.

**Marble diagram:**

```text
it          ---a----b---c--d-e--⊥
it.take 3   ---a----b---c⊥

it          ---a--⊥
it.take 3   ---a--⊥
```

**Termination properties:**

* `Finite` instance: only if `it` is productive
* `Productive` instance: only if `it` is productive

**Performance:**

This combinator incurs an additional O(1) cost with each output of `it`.
-/
@[always_inline, inline]
def IterM.take [Iterator α m β] (n : Nat) (it : IterM (α := α) m β) :
    IterM (α := Take α m) m β :=
  -- A countdown of `0` is reserved for the unbounded `toTake` variant, so the limit `n` is stored
  -- as `n + 1`. A countdown of `1` then means "no values left". Because `n + 1` is positive, the
  -- first disjunct of `Take.finite` holds without any `Finite α m` assumption.
  ⟨⟨n + 1, it, .inl (Nat.zero_lt_succ n)⟩⟩

/--
This combinator is only useful for advanced use cases.

Given a finite iterator `it`, `it.toTake` behaves exactly like `it` but has the same type as
`it.take n`, namely `IterM (α := Take α m) m β`. It imposes no limit on the number of outputs.

**Marble diagram:**

```text
it            ---a----b---c--d-e--⊥
it.toTake     ---a----b---c--d-e--⊥
```

**Termination properties:**

* `Finite` instance: always
* `Productive` instance: always

**Performance:**

This combinator incurs an additional O(1) cost with each output of `it`.
-/
@[always_inline, inline]
def IterM.toTake [Iterator α m β] [Finite α m] (it : IterM (α := α) m β) :
    IterM (α := Take α m) m β :=
  -- A countdown of `0` never reaches `1` (`0 - 1 = 0` on `Nat`), so no limit is ever hit.
  -- Finiteness is instead justified by the second disjunct of `Take.finite`, the `Finite α m`
  -- instance.
  ⟨⟨0, it, .inr inferInstance⟩⟩

/--
Every `Take` iterator whose countdown is positive is of the form `it₀.take k`. The witnesses are
the inner iterator and `k = countdown - 1`.
-/
theorem IterM.take.surjective_of_zero_lt {α : Type w} {m : Type w → Type w'} {β : Type w}
    [Iterator α m β] (it : IterM (α := Take α m) m β) (h : 0 < it.internalState.countdown) :
    ∃ (it₀ : IterM (α := α) m β) (k : Nat), it = it₀.take k :=
  ⟨it.internalState.inner, it.internalState.countdown - 1, by
    -- `take` stores `k + 1`, and `countdown - 1 + 1 = countdown` because `countdown ≥ 1` by `h`.
    simp [take, Nat.sub_add_cancel (m := 1) (n := it.internalState.countdown) (by omega)]⟩

namespace Iterators.Types

/--
The plausibility relation for the steps of a `Take` iterator `it`. It is used as
`IsPlausibleStep` in `Take.instIterator`, and every step returned by that instance's `step`
function carries a proof of it.
-/
inductive Take.PlausibleStep [Iterator α m β] (it : IterM (α := Take α m) m β) :
    (step : IterStep (IterM (α := Take α m) m β) β) → Prop where
  /--
  The inner iterator plausibly yields `out` and the limit is not yet reached. The successor has
  the countdown decreased by one (truncating at zero). The hypothesis `h` is named because the
  `by omega` proof of the successor's `finite` field needs it.
  -/
  | yield : ∀ {it' out}, it.internalState.inner.IsPlausibleStep (.yield it' out) →
      (h : it.internalState.countdown ≠ 1) → PlausibleStep it (.yield ⟨it.internalState.countdown - 1, it', it.internalState.finite.imp_left (by omega)⟩ out)
  /--
  The inner iterator plausibly skips and the limit is not yet reached. The countdown is unchanged.
  -/
  | skip : ∀ {it'}, it.internalState.inner.IsPlausibleStep (.skip it') →
      it.internalState.countdown ≠ 1 → PlausibleStep it (.skip ⟨it.internalState.countdown, it', it.internalState.finite⟩)
  /-- The inner iterator plausibly finishes. -/
  | done : it.internalState.inner.IsPlausibleStep .done → PlausibleStep it .done
  /-- The countdown is `1`, i.e. the limit has been reached, so the iterator finishes. -/
  | depleted : it.internalState.countdown = 1 →
      PlausibleStep it .done

/--
The `Iterator` instance for `Take`, with `Take.PlausibleStep` as its plausibility relation.

If the countdown is `1`, stepping returns `.done` without touching the inner iterator. Otherwise
the inner iterator is stepped once and its result is forwarded:
* a `yield` decreases the countdown by one;
* a `skip` keeps the countdown;
* `done` is passed through.
-/
@[always_inline, inline]
instance Take.instIterator [Monad m] [Iterator α m β] : Iterator (Take α m) m β where
  IsPlausibleStep := Take.PlausibleStep
  step it :=
    -- A countdown of `1` means the limit has been reached.
    if h : it.internalState.countdown = 1 then
      pure <| .deflate <| .done (.depleted h)
    else do
      match (← it.internalState.inner.step).inflate with
      | .yield it' out h' =>
        -- `imp_left` turns `countdown > 0` into `countdown - 1 > 0`, which `omega` proves from
        -- `h : countdown ≠ 1`. The `Finite α m` disjunct is carried over unchanged.
        pure <| .deflate <| .yield ⟨it.internalState.countdown - 1, it', (it.internalState.finite.imp_left (by omega))⟩ out (.yield h' h)
      | .skip it' h' => pure <| .deflate <| .skip ⟨it.internalState.countdown, it', it.internalState.finite⟩ (.skip h' h)
      | .done h' => pure <| .deflate <| .done (.done h')

/--
The relation used for the `FinitenessRelation` of `Take`, assuming the inner iterator is
productive.

Iterators are compared lexicographically: first by `countdown` (using `<` on `Nat`), then by a
termination measure of the inner iterator. The measure is `finitelyManySteps` when `Finite α m`
holds and `finitelyManySkips` otherwise.
-/
def Take.Rel (m : Type w → Type w') [Monad m] [Iterator α m β] [Productive α m] :
    IterM (α := Take α m) m β → IterM (α := Take α m) m β → Prop :=
  -- Branching on the proposition `Finite α m` relies on classical decidability.
  open scoped Classical in
  if _ : Finite α m then
    InvImage (Prod.Lex Nat.lt_wfRel.rel IterM.TerminationMeasures.Finite.Rel)
      (fun it => (it.internalState.countdown, it.internalState.inner.finitelyManySteps))
  else
    InvImage (Prod.Lex Nat.lt_wfRel.rel IterM.TerminationMeasures.Productive.Rel)
      (fun it => (it.internalState.countdown, it.internalState.inner.finitelyManySkips))

/--
`Take.Rel m it' it` holds whenever `it'` has a strictly smaller countdown than `it`, whatever
the inner iterators are.
-/
theorem Take.rel_of_countdown [Monad m] [Iterator α m β] [Productive α m]
    {it it' : IterM (α := Take α m) m β}
    (h : it'.internalState.countdown < it.internalState.countdown) : Take.Rel m it' it := by
  simp only [Rel]
  -- Both branches of `Rel` compare countdowns first, so the left component of `Prod.Lex` decides.
  split
  all_goals exact Prod.Lex.left _ _ h

/--
Two `take` iterators with the same limit `remaining` are related by `Take.Rel` whenever their
inner iterators are related by the `finitelyManySkips` termination measure.
-/
theorem Take.rel_of_inner [Monad m] [Iterator α m β] [Productive α m] {remaining : Nat}
    {it it' : IterM (α := α) m β}
    (h : it'.finitelyManySkips.Rel it.finitelyManySkips) :
    Take.Rel m (it'.take remaining) (it.take remaining) := by
  simp only [Rel]
  -- The countdowns coincide, so the inner-iterator component of `Prod.Lex` must decrease.
  split
  next =>
    -- `Finite α m` branch: `.of_productive` lifts `h` to the `finitelyManySteps` relation.
    exact Prod.Lex.right _ (.of_productive h)
  next =>
    exact Prod.Lex.right _ h

/--
If `it` and `it'` both have countdown `0`, then `Take.Rel m it' it` follows from the inner
iterators being related by the `finitelyManySteps` termination measure.

With countdown `0`, `Take.finite` can only hold through its `Finite α m` disjunct. The statement
therefore extracts that instance from `it.internalState.finite` with `haveI`. There is no
`[Productive α m]` argument, so the `Productive` instance that `Take.Rel` requires is presumably
derived from this `Finite` instance (TODO confirm).
-/
theorem Take.rel_of_zero_of_inner [Monad m] [Iterator α m β]
    {it it' : IterM (α := Take α m) m β}
    (h : it.internalState.countdown = 0) (h' : it'.internalState.countdown = 0)
    (h'' : haveI := it.internalState.finite.resolve_left (by omega); it'.internalState.inner.finitelyManySteps.Rel it.internalState.inner.finitelyManySteps) :
    haveI := it.internalState.finite.resolve_left (by omega)
    Take.Rel m it' it := by
  haveI := it.internalState.finite.resolve_left (by omega)
  -- With the instance in scope, `Rel` reduces to its `Finite` branch, and both countdowns are `0`.
  simp only [Rel, this, ↓reduceDIte, InvImage, h, h']
  exact Prod.Lex.right _ h''

/--
A `FinitenessRelation` for `Take`, based on `Take.Rel`, assuming the inner iterator is
productive.
-/
private def Take.instFinitenessRelation [Monad m] [Iterator α m β]
    [Productive α m] :
    FinitenessRelation (Take α m) m where
  Rel := Take.Rel m
  wf := by
    -- Unfold `Rel` and treat both branches (whether or not `Finite α m` holds) the same way: each
    -- is an inverse image of a lexicographic product of well-founded relations.
    rw [Rel]
    split
    all_goals
      apply InvImage.wf
      refine ⟨fun (a, b) => Prod.lexAccessible (WellFounded.apply ?_ a) (WellFounded.apply ?_) b⟩
      · exact WellFoundedRelation.wf
      · exact WellFoundedRelation.wf
  subrelation {it it'} h := by
    obtain ⟨step, h, h'⟩ := h
    cases h'
    case yield it' out k h' h'' =>
      cases h
      -- Either the countdown is positive, or the inner iterator is finite.
      cases it.internalState.finite
      · apply rel_of_countdown
        simp only
        omega
      · by_cases h : it.internalState.countdown = 0
        -- Countdown `0` stays `0`, so progress must come from the inner iterator's `yield`.
        · simp only [h, Nat.zero_le, Nat.sub_eq_zero_of_le]
          apply rel_of_zero_of_inner h rfl
          exact .single ⟨_, rfl, h'⟩
        -- A positive countdown strictly decreases on `yield`.
        · apply rel_of_countdown
          simp only
          omega
    case skip it' out k h' h'' =>
      cases h
      by_cases h : it.internalState.countdown = 0
      · simp only [h]
        apply Take.rel_of_zero_of_inner h rfl
        exact .single ⟨_, rfl, h'⟩
      -- With a positive countdown, `it` has the form `it₀.take k`. A `skip` keeps the countdown, so
      -- progress comes from the inner iterator's `finitelyManySkips` measure.
      · obtain ⟨it, k, rfl⟩ := IterM.take.surjective_of_zero_lt it (by omega)
        apply Take.rel_of_inner
        exact IterM.TerminationMeasures.Productive.rel_of_skip h'
    -- For the remaining constructors the step is `.done`, which presumably has no successor, so
    -- `cases h` refutes `h` (TODO confirm).
    case done _ =>
      cases h
    case depleted _ =>
      cases h

/--
Iterators with state `Take α m` are finite whenever the underlying iterator type `α` is
productive.
-/
instance Take.instFinite [Monad m] [Iterator α m β] [Productive α m] :
    Finite (Take α m) m := by
  -- Tactic mode is presumably kept so that the private relation does not appear in an exposed
  -- term (TODO confirm).
  refine Finite.of_finitenessRelation ?_
  exact instFinitenessRelation

/--
The `IteratorLoop` instance for `Take`, given by `IteratorLoop.defaultImplementation`.
-/
instance Take.instIteratorLoop {n : Type x → Type x'} [Monad m] [Monad n] [Iterator α m β] :
    IteratorLoop (Take α m) m n :=
  IteratorLoop.defaultImplementation

end Std.Iterators.Types