feat: add MPL specs for slice for ... in (#11141)

This PR provides a polymorphic `ForIn` instance for slices and an MPL
`spec` lemma for the iteration over slices using `for ... in`. It also
provides a version specialized to `Subarray`.
This commit is contained in:
Paul Reichert 2025-11-17 16:58:29 +01:00 committed by GitHub
parent 8671f81aa5
commit 8eb0293098
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 148 additions and 22 deletions

View file

@ -2231,8 +2231,8 @@ theorem push_eq_flatten_iff {xss : Array (Array α)} {ys : Array α} {y : α} :
-- zs = cs ++ ds.flatten := by sorry
/-- Two arrays of subarrays are equal iff their flattens coincide, as well as the sizes of the
subarrays. -/
/-- Two arrays of arrays are equal iff their flattens coincide, as well as the sizes of the
arrays. -/
theorem eq_iff_flatten_eq {xss₁ xss₂ : Array (Array α)} :
xss₁ = xss₂ ↔ xss₁.flatten = xss₂.flatten ∧ map size xss₁ = map size xss₂ := by
cases xss₁ using array₂_induction with

View file

@ -89,6 +89,12 @@ instance {x : γ} {State : Type w} {iter}
IteratorCollect (α := i.State) m n :=
inferInstanceAs <| IteratorCollect (α := State) m n
instance {x : γ} {State : Type w} {iter} [Monad m] [Monad n]
[Iterator (α := State) m β] [IteratorCollect State m n] [LawfulIteratorCollect State m n] :
letI i : ToIterator x m β := .ofM State iter
LawfulIteratorCollect (α := i.State) m n :=
inferInstanceAs <| LawfulIteratorCollect (α := State) m n
instance {x : γ} {State : Type w} {iter}
[Iterator (α := State) m β] [IteratorCollectPartial State m n] :
letI i : ToIterator x m β := .ofM State iter
@ -101,6 +107,12 @@ instance {x : γ} {State : Type w} {iter}
IteratorLoop (α := i.State) m n :=
inferInstanceAs <| IteratorLoop (α := State) m n
instance {x : γ} {State : Type w} {iter} [Monad m] [Monad n]
[Iterator (α := State) m β] [IteratorLoop State m n] [LawfulIteratorLoop State m n]:
letI i : ToIterator x m β := .ofM State iter
LawfulIteratorLoop (α := i.State) m n :=
inferInstanceAs <| LawfulIteratorLoop (α := State) m n
instance {x : γ} {State : Type w} {iter}
[Iterator (α := State) m β] [IteratorLoopPartial State m n] :
letI i : ToIterator x m β := .ofM State iter

View file

@ -7,12 +7,14 @@ module
prelude
public import Init.Data.Slice.Array.Basic
public import Init.Data.Slice.Operations
import Init.Data.Iterators.Combinators.Attach
public import Init.Data.Iterators.Combinators.ULift
import all Init.Data.Range.Polymorphic.Basic
public import Init.Data.Range.Polymorphic.Iterators
public import Init.Data.Slice.Operations
import Init.Omega
import Init.Data.Iterators.Lemmas.Combinators.Monadic.FilterMap
public section
@ -24,7 +26,7 @@ open Std Slice PRange Iterators
variable {shape : RangeShape} {α : Type u}
instance {s : Subarray α} : ToIterator s Id α :=
instance {s : Slice (Internal.SubarrayData α)} : ToIterator s Id α :=
.of _
(Rco.Internal.iter (s.internalRepresentation.start...<s.internalRepresentation.stop)
|>.attachWith (· < s.internalRepresentation.array.size) ?h
@ -39,22 +41,24 @@ where finally
universe v w
@[no_expose] instance {s : Subarray α} : Iterator (ToIterator.State s Id) Id α := inferInstance
@[no_expose] instance {s : Subarray α} : Finite (ToIterator.State s Id) Id := inferInstance
@[no_expose] instance {s : Subarray α} : IteratorCollect (ToIterator.State s Id) Id Id := inferInstance
@[no_expose] instance {s : Subarray α} : IteratorCollectPartial (ToIterator.State s Id) Id Id := inferInstance
@[no_expose] instance {s : Subarray α} {m : Type v → Type w} [Monad m] :
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} : Iterator (ToIterator.State s Id) Id α := inferInstance
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} : Finite (ToIterator.State s Id) Id := inferInstance
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} : IteratorCollect (ToIterator.State s Id) Id Id := inferInstance
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} : LawfulIteratorCollect (ToIterator.State s Id) Id Id := inferInstance
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} : IteratorCollectPartial (ToIterator.State s Id) Id Id := inferInstance
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} {m : Type v → Type w} [Monad m] :
IteratorLoop (ToIterator.State s Id) Id m := inferInstance
@[no_expose] instance {s : Subarray α} {m : Type v → Type w} [Monad m] :
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} {m : Type v → Type w} [Monad m] :
LawfulIteratorLoop (ToIterator.State s Id) Id m := inferInstance
@[no_expose] instance {s : Slice (Internal.SubarrayData α)} {m : Type v → Type w} [Monad m] :
IteratorLoopPartial (ToIterator.State s Id) Id m := inferInstance
instance : SliceSize (Internal.SubarrayData α) where
size s := s.internalRepresentation.stop - s.internalRepresentation.start
@[no_expose]
instance {α : Type u} {m : Type v → Type w} :
ForIn m (Subarray α) α where
forIn xs init f := forIn (Std.Slice.Internal.iter xs) init f
instance {α : Type u} {m : Type v → Type w} [Monad m] :
ForIn m (Subarray α) α :=
inferInstance
/-!
Without defining the following function `Subarray.foldlM`, it is still possible to call

View file

@ -8,6 +8,7 @@ module
prelude
import all Init.Data.Array.Subarray
import all Init.Data.Slice.Array.Basic
import Init.Data.Slice.Lemmas
public import Init.Data.Slice.Array.Iterator
import all Init.Data.Slice.Array.Iterator
import all Init.Data.Slice.Operations
@ -16,11 +17,11 @@ import all Init.Data.Range.Polymorphic.Lemmas
public import Init.Data.Slice.Lemmas
public import Init.Data.Iterators.Lemmas
open Std.Iterators Std.PRange
open Std Std.Iterators Std.PRange Std.Slice
namespace Std.Slice.Array
namespace Subarray
theorem internalIter_rco_eq {α : Type u} {s : Subarray α} :
theorem internalIter_eq {α : Type u} {s : Subarray α} :
Internal.iter s = (Rco.Internal.iter (s.start...<s.stop)
|>.attachWith (· < s.array.size)
(fun out h => h
@ -40,7 +41,7 @@ theorem toList_internalIter {α : Type u} {s : Subarray α} :
|> Rco.lt_upper_of_mem
|> (Nat.lt_of_lt_of_le · s.stop_le_array_size))
|>.map fun i => s.array[i.1]) := by
rw [internalIter_rco_eq, Iter.toList_map, Iter.toList_uLift, Iter.toList_attachWith]
rw [internalIter_eq, Iter.toList_map, Iter.toList_uLift, Iter.toList_attachWith]
simp [Rco.toList]
public instance : LawfulSliceSize (Internal.SubarrayData α) where
@ -50,4 +51,22 @@ public instance : LawfulSliceSize (Internal.SubarrayData α) where
Rco.size_toArray, Rco.size, Rxo.HasSize.size, Rxc.HasSize.size]
omega
end Std.Slice.Array
public theorem toArray_eq_sliceToArray {α : Type u} {s : Subarray α} :
s.toArray = Slice.toArray s := by
simp [Subarray.toArray, Array.ofSubarray]
@[simp]
public theorem forIn_toList {α : Type u} {s : Subarray α}
{m : Type v → Type w} [Monad m] [LawfulMonad m] {γ : Type v} {init : γ}
{f : αγ → m (ForInStep γ)} :
ForIn.forIn s.toList init f = ForIn.forIn s init f :=
Slice.forIn_toList
@[simp]
public theorem forIn_toArray {α : Type u} {s : Subarray α}
{m : Type v → Type w} [Monad m] [LawfulMonad m] {γ : Type v} {init : γ}
{f : αγ → m (ForInStep γ)} :
ForIn.forIn s.toArray init f = ForIn.forIn s init f :=
Slice.forIn_toArray
end Subarray

View file

@ -8,7 +8,9 @@ module
prelude
public import Init.Data.Slice.Operations
import all Init.Data.Slice.Operations
import Init.Data.Iterators.Consumers
import Init.Data.Iterators.Lemmas.Consumers
public import Init.Data.List.Control
public section
@ -23,6 +25,45 @@ theorem Internal.iter_eq_toIteratorIter {γ : Type u} {s : Slice γ}
Internal.iter s = ToIterator.iter s :=
(rfl)
theorem forIn_internalIter {γ : Type u} {β : Type v}
{m : Type w → Type x} [Monad m] {δ : Type w}
[∀ s : Slice γ, ToIterator s Id β]
[∀ s : Slice γ, Iterator (ToIterator.State s Id) Id β]
[∀ s : Slice γ, IteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, LawfulIteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, Finite (ToIterator.State s Id) Id] {s : Slice γ}
{init : δ} {f : β → δ → m (ForInStep δ)} :
ForIn.forIn (Internal.iter s) init f = ForIn.forIn s init f :=
(rfl)
@[simp]
public theorem forIn_toList {γ : Type u} {β : Type v}
{m : Type w → Type x} [Monad m] [LawfulMonad m] {δ : Type w}
[∀ s : Slice γ, ToIterator s Id β]
[∀ s : Slice γ, Iterator (ToIterator.State s Id) Id β]
[∀ s : Slice γ, IteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, LawfulIteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, IteratorCollect (ToIterator.State s Id) Id Id]
[∀ s : Slice γ, LawfulIteratorCollect (ToIterator.State s Id) Id Id]
[∀ s : Slice γ, Finite (ToIterator.State s Id) Id] {s : Slice γ}
{init : δ} {f : β → δ → m (ForInStep δ)} :
ForIn.forIn s.toList init f = ForIn.forIn s init f := by
rw [← forIn_internalIter, ← Iter.forIn_toList, Slice.toList]
@[simp]
public theorem forIn_toArray {γ : Type u} {β : Type v}
{m : Type w → Type x} [Monad m] [LawfulMonad m] {δ : Type w}
[∀ s : Slice γ, ToIterator s Id β]
[∀ s : Slice γ, Iterator (ToIterator.State s Id) Id β]
[∀ s : Slice γ, IteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, LawfulIteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, IteratorCollect (ToIterator.State s Id) Id Id]
[∀ s : Slice γ, LawfulIteratorCollect (ToIterator.State s Id) Id Id]
[∀ s : Slice γ, Finite (ToIterator.State s Id) Id] {s : Slice γ}
{init : δ} {f : β → δ → m (ForInStep δ)} :
ForIn.forIn s.toArray init f = ForIn.forIn s init f := by
rw [← forIn_internalIter, ← Iter.forIn_toArray, Slice.toArray]
theorem Internal.size_eq_count_iter [∀ s : Slice γ, ToIterator s Id β]
[∀ s : Slice γ, Iterator (ToIterator.State s Id) Id β] {s : Slice γ}
[Finite (ToIterator.State s Id) Id]

View file

@ -75,6 +75,14 @@ def toListRev (s : Slice γ) [ToIterator s Id β] [Iterator (ToIterator.State s
[Finite (ToIterator.State s Id) Id] : List β :=
Internal.iter s |>.toListRev
instance {γ : Type u} {β : Type v} [∀ s : Slice γ, ToIterator s Id β]
[∀ s : Slice γ, Iterator (ToIterator.State s Id) Id β]
[∀ s : Slice γ, IteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, Finite (ToIterator.State s Id) Id] :
ForIn m (Slice γ) β where
forIn s init f :=
forIn (Internal.iter s) init f
/--
Folds a monadic operation from left to right over the elements in a slice.
An accumulator of type `β` is constructed by starting with `init` and monadically combining each

View file

@ -9,6 +9,8 @@ prelude
public import Std.Do.Triple.Basic
public import Init.Data.Range.Polymorphic.Iterators
import Init.Data.Range.Polymorphic
public import Init.Data.Slice.Array
public import Init.Data.Iterators.ToIterator
-- This public import is a workaround for #10652.
-- Without it, adding the `spec` attribute for `instMonadLiftTOfMonadLift` will fail.
@ -1087,6 +1089,33 @@ theorem Spec.forIn_rii {α β : Type u} {m : Type u → Type v} {ps : PostShape}
simp only [forIn]
apply Spec.forIn'_rii inv step
open Std.Iterators in
@[spec]
theorem Spec.forIn_slice {m : Type w → Type x} {ps : PostShape}
[Monad m] [WPMonad m ps]
{γ : Type u} {β : Type w}
[LawfulMonad m] {δ : Type w}
[∀ s : Slice γ, ToIterator s Id β]
[∀ s : Slice γ, Iterator (ToIterator.State s Id) Id β]
[∀ s : Slice γ, IteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, LawfulIteratorLoop (ToIterator.State s Id) Id m]
[∀ s : Slice γ, IteratorCollect (ToIterator.State s Id) Id Id]
[∀ s : Slice γ, LawfulIteratorCollect (ToIterator.State s Id) Id Id]
[∀ s : Slice γ, Finite (ToIterator.State s Id) Id]
{init : δ} {f : β → δ → m (ForInStep δ)}
{xs : Slice γ}
(inv : Invariant xs.toList δ ps)
(step : ∀ pref cur suff (h : xs.toList = pref ++ cur :: suff) b,
Triple
(f cur b)
(inv.1 (⟨pref, cur::suff, h.symm⟩, b))
(fun r => match r with
| .yield b' => inv.1 (⟨pref ++ [cur], suff, by simp [h]⟩, b')
| .done b' => inv.1 (⟨xs.toList, [], by simp⟩, b'), inv.2)) :
Triple (forIn xs init f) (inv.1 (⟨[], xs.toList, rfl⟩, init)) (fun b => inv.1 (⟨xs.toList, [], by simp⟩, b), inv.2) := by
simp only [← Slice.forIn_toList]
exact Spec.forIn_list inv step
@[spec]
theorem Spec.forIn'_array {α β : Type u} {m : Type u → Type v} {ps : PostShape}
[Monad m] [WPMonad m ps]

View file

@ -4,8 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Graf
-/
import Std.Tactic.Do
import Std.Tactic.Do.Syntax
import Std
import Lean.Elab.Tactic.Do.VCGen
@ -794,6 +792,23 @@ theorem mem_mergeWithAll [LawfulEqCmp cmp] {m₁ m₂ : ExtTreeMap α β cmp} {f
end KimsUnivPolyUseCase
namespace Slices
def subarraySum (xs : Subarray Nat) : Nat := Id.run do
let mut sum := 0
for x in xs do
sum := sum + x
return sum
theorem subarraySum_correct {xs : Subarray Nat} : subarraySum xs = xs.toList.sum := by
generalize h : subarraySum xs = r
apply Id.of_wp_run_eq h
mvcgen
case inv1 => exact ⇓⟨cursor, prefixSum⟩ => ⌜prefixSum = cursor.prefix.sum⌝
all_goals simp_all
end Slices
namespace PatricksFastExp
def naive_expo (x n : Nat) : Nat := Id.run do
@ -817,8 +832,6 @@ def fast_expo (x n : Nat) : Nat := Id.run do
return y
open Std.Do
theorem naive_expo_correct (x n : Nat) : naive_expo x n = x^n := by
generalize h : naive_expo x n = r
apply Id.of_wp_run_eq h