lean4-htt/tests/elab/grind_qsort.lean
Garmelon 08eb78a5b2
chore: switch to new test/bench suite (#12590)
This PR sets up the new integrated test/bench suite. It then migrates
all benchmarks and some related tests to the new suite. There's also
some documentation and some linting.

For now, a lot of the old tests are left alone so this PR doesn't become
even larger than it already is. Eventually, all tests should be migrated
to the new suite though so there isn't a confusing mix of two systems.
2026-02-25 13:51:53 +00:00

375 lines
17 KiB
Text
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
module
import all Init.Data.Array.QSort.Basic
-- TODO: when `grind` is ready for production use, move this file to `src/Init/Data/Array/QSort/Lemmas.lean`.
/-!
# Verification of `Array.qsort`
This file contains a verification of the `Array.qsort` function,
using the `grind` tactic.
The theorems are:
* `size_qsort : (qsort as lt lo hi).size = as.size`
* `qsort_perm : qsort as lt lo hi ~ as`
And when `lt` is antisymmetric and `¬ lt a b` is transitive, we have:
* `qsort_sorted' : lo ≤ i < j ≤ hi → ¬ lt (as.qsort lt lo hi)[j] (as.qsort lt lo hi)[i]`
* `qsort_sorted : i < j → ¬ lt (as.qsort lt)[j] (as.qsort lt)[i]`
(There is not currently a public theorem that `(qsort as lt lo hi)[i] = as[i]` when `i < lo` or `hi < i`.)
-/
namespace Array
open List Vector
-- These attributes still need to be moved to the standard library.
-- set_option trace.grind.ematch.pattern true in
-- attribute [grind] Vector.getElem?_eq_getElem -- This one requires some consideration! -- Probably not need, see Vector.Perm.extract' below.
-- Hmm, we don't seem to have the Array analogues of these!
attribute [grind] Vector.toArray_perm_iff
attribute [grind] Vector.perm_toArray_iff
attribute [grind] Vector.swap_perm
attribute [grind] List.Perm.refl
attribute [grind] Array.Perm.refl
attribute [grind] Vector.Perm.refl
-- attribute [grind] Array.Perm.extract
-- attribute [grind] Vector.Perm.extract
-- These are just the patterns resulting from `grind`, but the behaviour should be explained!
grind_pattern List.Perm.trans => l₁ ~ l₂, l₁ ~ l₃
grind_pattern Array.Perm.trans => xs ~ ys, xs ~ zs
grind_pattern Vector.Perm.trans => xs ~ ys, xs ~ zs
/-- Variant of `List.Perm.take` specifying the permutation is constant after `i` elementwise. -/
theorem _root_.List.Perm.take_of_getElem {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat}
(w : ∀ j, i ≤ j → (_ : j < l₁.length) → l₁[j] = l₂[j]'(by have := h.length_eq; omega)) :
l₁.take i ~ l₂.take i := by
apply h.take_of_getElem?
sorry
/-- Variant of `List.Perm.drop` specifying the permutation is constant before `i` elementwise. -/
theorem _root_.List.Perm.drop_of_getElem {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat}
(w : ∀ j, j < i → (_ : j < l₁.length) → l₁[j] = l₂[j]'(by have := h.length_eq; omega)) :
l₁.drop i ~ l₂.drop i := by
apply h.drop_of_getElem?
sorry
theorem _root_.Array.Perm.extract' {xs ys : Array α} (h : xs ~ ys) {lo hi : Nat}
(wlo : ∀ i, i < lo → (_ : i < xs.size) → xs[i] = ys[i]'(by have := h.size_eq; omega))
(whi : ∀ i, hi ≤ i → (_ : i < xs.size) → xs[i] = ys[i]'(by have := h.size_eq; omega)) :
xs.extract lo hi ~ ys.extract lo hi := by
rcases xs with ⟨xs⟩
rcases ys with ⟨ys⟩
simp_all only [perm_iff_toList_perm, List.getElem?_toArray, List.extract_toArray,
List.extract_eq_take_drop]
apply List.Perm.take_of_getElem (w := fun i h₁ h₂ => by simpa using whi (lo + i) (by omega) sorry)
apply List.Perm.drop_of_getElem (w := wlo)
simpa using List.perm_iff_toArray_perm.mpr h
theorem _root_.Vector.Perm.extract' {xs ys : Vector α n} (h : xs ~ ys) {lo hi : Nat}
(wlo : ∀ i, i < lo → (_ : i < n) → xs[i] = ys[i]) (whi : ∀ i, hi ≤ i → (_ : i < n) → xs[i] = ys[i]) :
xs.extract lo hi ~ ys.extract lo hi := by
rcases xs with ⟨xs, rfl⟩
rcases ys with ⟨ys, h⟩
exact ⟨Array.Perm.extract' h.toArray (by simpa using wlo) (by simpa using whi)⟩
attribute [grind] Array.Perm.extract'
attribute [grind] Vector.Perm.extract'
variable (lt : αα → Bool) (lo hi : Nat)
@[simp, grind =] theorem size_qsort (as : Array α) :
(qsort as lt lo hi).size = as.size := by
grind [qsort]
private theorem qpartition_loop_perm (as : Vector α n)
(hhi : hi < n) (ilo : lo ≤ i) (ik : i ≤ k) (w : k ≤ hi) :
(qpartition.loop lt lo hi hhi pivot as i k).2 ~ as := by
fun_induction qpartition.loop with grind
@[local grind]
private theorem qpartition_perm
(as : Vector α n) (w : lo ≤ hi) (hlo : lo < n) (hhi : hi < n) :
(qpartition as lt lo hi).2 ~ as := by
unfold qpartition
refine Vector.Perm.trans (qpartition_loop_perm ..) ?_
repeat' first
| split
| grind
| refine Vector.Perm.trans (Vector.swap_perm ..) ?_
private theorem qsort_sort_perm
(as : Vector α n) (w : lo ≤ hi) (hlo : lo < n) (hhi : hi < n) :
qsort.sort lt as lo hi w hlo hhi ~ as := by
fun_induction qsort.sort with grind
grind_pattern qsort_sort_perm => qsort.sort lt as lo hi w hlo hhi
theorem qsort_perm (as : Array α) : qsort as lt lo hi ~ as := by
grind [qsort]
private theorem getElem_qpartition_loop_snd_of_lt_lo
(hhi : hi < n) (as : Vector α n) (i k : Nat) (ilo : lo ≤ i) (ik : i ≤ k) (w : k ≤ hi) (w' : lo ≤ hi)
(l : Nat) (h : l < lo) : (qpartition.loop lt lo hi hhi pivot as i k).2[l] = as[l] := by
fun_induction qpartition.loop <;> grind
private theorem getElem_qpartition_snd_of_lt_lo (as : Vector α n)
(hhi : hi < n) (w : lo ≤ hi)
(k : Nat) (h : k < lo) : (qpartition as lt lo hi).2[k] = as[k] := by
grind [qpartition, getElem_qpartition_loop_snd_of_lt_lo]
@[local grind =] private theorem getElem_qsort_sort_of_lt_lo
(as : Vector α n)
(hlo : lo < n) (hhi : hi < n) (w : lo ≤ hi)
(i : Nat) (h : i < lo) : (qsort.sort lt as lo hi)[i] = as[i] := by
fun_induction qsort.sort with grind [getElem_qpartition_snd_of_lt_lo]
private theorem getElem_qpartition_loop_snd_of_hi_lt
(hhi : hi < n) (as : Vector α n) (i k)
(ilo : lo ≤ i) (ik : i ≤ k) (w : k ≤ hi) (w' : lo ≤ hi) (z : i ≤ hi)
(l : Nat) (h : hi < l) (h' : l < n) : (qpartition.loop lt lo hi hhi pivot as i k).2[l] = as[l] := by
fun_induction qpartition.loop <;> grind
private theorem getElem_qpartition_snd_of_hi_lt (as : Vector α n)
(hhi : hi < n) (w : lo ≤ hi)
(k : Nat) (h : hi < k) (h' : k < n) : (qpartition as lt lo hi).2[k] = as[k] := by
grind [qpartition, getElem_qpartition_loop_snd_of_hi_lt]
@[local grind] private theorem getElem_qsort_sort_of_hi_lt
(as : Vector α n) (w : lo ≤ hi)
(hlo : lo < n) (hhi : hi < n) (w : lo ≤ hi)
(i : Nat) (h : hi < i) (h' : i < n) : (qsort.sort lt as lo hi)[i] = as[i] := by
fun_induction qsort.sort with grind [getElem_qpartition_snd_of_hi_lt]
private theorem extract_qsort_sort_perm (as : Vector α n) (lt : αα → Bool)
(hlo := by grind) (hhi := by grind) (w : lo ≤ hi := by grind) :
((qsort.sort lt as lo hi).extract lo (hi + 1)) ~ (as.extract lo (hi + 1)) := by
grind [qsort_sort_perm]
private theorem getElem_qsort_sort_mem
(as : Vector α n) (hhi : hi < n) (i : Nat) (h : i < n) (_ : lo ≤ i) (_ : i ≤ hi) :
(qsort.sort lt as lo hi)[i] ∈ as.extract lo (hi + 1) := by
rw [← (extract_qsort_sort_perm lo hi as lt).mem_iff, Vector.mem_extract_iff_getElem]
exact ⟨i - lo, by grind⟩
private theorem qpartition_loop_spec₁
(hhi : hi < n) (ilo : lo ≤ i) (ik : i ≤ k) (w : k < n) (khi : k ≤ hi)
(as : Vector α n) (hpivot : pivot = as[hi])
(q : ∀ l, (hk₁ : lo ≤ l) → (hk₂ : l < i) → lt as[l] as[hi]) (mid as')
(w_mid : mid = (qpartition.loop lt lo hi hhi pivot as i k).fst.1) (hmid : mid < n)
(w_as : as' = (qpartition.loop lt lo hi hhi pivot as i k).2) :
∀ l, (h₁ : lo ≤ l) → (h₂ : l < mid) → lt as'[l] as'[mid] := by
fun_induction qpartition.loop with unfold qpartition.loop at w_mid w_as
| case1
| case2 => apply_assumption <;> grind
| case3 => grind
private theorem qpartition_loop_spec₂
(hhi : hi < n) (ilo : lo ≤ i) (ik : i ≤ k) (w : k < n) (khi : k ≤ hi)
(as : Vector α n) (hpivot : pivot = as[hi])
(q : ∀ l, (hk₁ : i ≤ l) → (hk₂ : l < k) → !lt as[l] as[hi]) (mid as')
(w_mid : mid = (qpartition.loop lt lo hi hhi pivot as i k).fst.1) (hmid : mid < n)
(w_as : as' = (qpartition.loop lt lo hi hhi pivot as i k).2) :
∀ l, (h₁ : mid < l) → (h₂ : l ≤ hi) → lt as'[l] as'[mid] = false := by
fun_induction qpartition.loop with grind
/--
All elements in the active range before the pivot, are less than the pivot.
-/
private theorem qpartition_spec₁
(hhi : hi < n) (w : lo ≤ hi)
(as : Vector α n) (mid as')
(w_mid : mid = (qpartition as lt lo hi).fst.1) (hmid : mid < n)
(w_as : as' = (qpartition as lt lo hi).2) :
∀ i, (h₁ : lo ≤ i) → (h₂ : i < mid) → lt as'[i] as'[mid] := by
set_option trace.grind.ematch.pattern true in
grind [qpartition, qpartition_loop_spec₁]
/--
All elements in the active range after the pivot, are greater than or equal to the pivot.
-/
private theorem qpartition_spec₂
(hhi : hi < n) (w : lo ≤ hi)
(as : Vector α n) (mid as')
(w_mid : mid = (qpartition as lt lo hi).fst.1) (hmid : mid < n)
(w_as : as' = (qpartition as lt lo hi).2) :
∀ i, (h₁ : mid < i) → (h₂ : i ≤ hi) → lt as'[i] as'[mid] = false := by
grind [qpartition, qpartition_loop_spec₂]
/-!
We now need to deal with a corner case:
we need to show that `qpartition` only returns a value `≥ hi` when `hi ≤ lo`
(and hence the slice of the array between `lo` and `hi` (inclusive) is trivially already sorted).
We prove two preliminary lemmas about `qpartition.loop`.
-/
/-- If we already have `i < k`, then we're sure to return something less than `hi`. -/
private theorem qpartition_loop_lt_hi₁
(h : lo < hi) (ilo : lo ≤ i) (ik : i < k) (w : k ≤ hi) (z : k ≤ hi) (ik' : i ≤ k) :
(qpartition.loop lt lo hi hhi pivot as i k).1.val < hi := by
fun_induction qpartition.loop with grind
/--
Otherwise, if there is some position `k' ≥ k` which is greater than or equal to the pivot,
then when we reach that we'll be sure `i < k`, and hence the previous lemma will apply,
and so we're sure to return something less than `hi`.
-/
private theorem qpartition_loop_lt_hi₂
{as : Vector α n} (h : lo < hi) (ilo : lo ≤ i) (ik : i ≤ k) (w : k < n) (z : k ≤ hi)
(q : ∃ (k' : Nat) (hj' : k' < n), k' ≥ k ∧ k' < hi ∧ ¬ lt as[k'] pivot) :
(qpartition.loop lt lo hi hhi pivot as i k).1.val < hi := by
fun_induction qpartition.loop with
| case1 =>
-- It would be nice if a more aggressive mode in `grind` would do this.
apply_assumption <;> grind
| case2 => grind [qpartition_loop_lt_hi₁]
| case3 => grind
/-- The only way `qpartition` returns a pivot position `≥ hi` is if `hi ≤ lo`. -/
private theorem qpartition_fst_lt_hi (lt_asymm : ∀ {a b}, lt a b → ¬ lt b a)
(as : Vector α n) (hhi : hi < n) (w : lo < hi) : (qpartition as lt lo hi).fst.1 < hi := by
apply qpartition_loop_lt_hi₂ lt lo hi w
· grind
· exact ⟨(lo + hi)/2, by grind⟩
private theorem qsort_sort_spec
(lt_asymm : ∀ {a b}, lt a b → ¬ lt b a)
(le_trans : ∀ {a b c}, ¬ lt b a → ¬ lt c b → ¬ lt c a)
(as : Vector α n) (lo hi : Nat) (hhi : hi < n) (w : lo ≤ hi)
(as' : Vector α n) (w_as : as' = qsort.sort lt as lo hi) :
∀ i, (h₁ : lo ≤ i) → (h₂ : i < hi) → ¬ lt (as')[i + 1] as'[i] := by
unfold qsort.sort at w_as
split at w_as <;> rename_i w₁
· -- The interesting case, where `lo < hi`.
intro i h₁ h₂
-- Decompose `qpartition as lt lo hi` into `mid` (the pivot) and `as'` (the partitioned array).
split at w_as <;> rename_i mid hmid as' w₂
split at w_as <;> rename_i w₃
· -- If the pivot was at least `hi`, then we get a contradiction from `lo < hi`.
simp only [Prod.ext_iff, Subtype.ext_iff] at w₂
obtain ⟨rfl, rfl⟩ := w₂
have := qpartition_fst_lt_hi lt lo hi lt_asymm as hhi w₁
grind
· -- Now we know `lo ≤ mid < hi`.
subst w_as
if p₁ : i < mid then
-- If `i < mid`, then the second stage of sorting is only
-- moving elements above where we're looking.
rw [getElem_qsort_sort_of_lt_lo (i := i)]
rw [getElem_qsort_sort_of_lt_lo (i := i + 1)]
-- And so we can apply the theorem recursively replacing `hi` with `mid`.
apply qsort_sort_spec lt_asymm le_trans as' lo mid
-- The remaining arithmetic side conditions are easily resolved.
all_goals grind
else
replace p₁ : mid ≤ i := by grind
-- If `mid ≤ i`, we need to consider two cases.
if p₃ : mid = i then
-- The tricky case, where `mid = i`.
subst i
-- On the right hand side, the index is below the range where the second stage of sorting is happening,
-- so we can drop that sort.
rw [getElem_qsort_sort_of_lt_lo (i := mid)]
-- The `mid` element of `qsort.sort lt as' lo mid ⋯`
-- is *some* element `lo + k` of `as'` in the range `lo ≤ lo + k ≤ mid`.
have z := getElem_qsort_sort_mem lt lo mid as' ?_ mid ?_ ?_ ?_
rw [Vector.mem_extract_iff_getElem] at z
obtain ⟨k, hk, z⟩ := z
rw [← z]
clear z
-- Similarly, the `mid + 1` element on the left hand side
-- is some element `mid + 1 + k'` of `qsort.sort lt as' lo mid ⋯`
-- in the range `mid + 1 ≤ mid + 1 + k' ≤ hi`
have z := getElem_qsort_sort_mem lt (mid + 1) hi
(qsort.sort lt as' lo mid ?_ ?_) ?_ (mid + 1) ?_ ?_ ?_
rw [Vector.mem_extract_iff_getElem] at z
obtain ⟨k', hk', z⟩ := z
rw [← z]
clear z
-- And then the first stage sort on the left hand side can't have any effect,
-- as it only moves elements between `lo` and `mid` inclusive.
rw [getElem_qsort_sort_of_hi_lt]
· by_cases p : lo + k = mid
· -- Now if `lo + k = mid`,
-- the element `as'[mid + 1 + k']` is in the top part of the partitioned array,
-- and `as[lo + k]` is the pivot, so we get the inequality from the specification of `qpartition`.
grind [qpartition_spec₂]
· -- Otherwise, we use transitivity:
-- `as[lo + k']` is in the bottom part, so is strictly less than the pivot,
-- while `as'[mid + 1 + k']` is in the top, so greater than or equal to the pivot.
apply le_trans (b := as'[mid])
· grind [qpartition_spec₁]
· grind [qpartition_spec₂]
-- Various arithmetic side conditions remain from the rewriting,
-- but are now all easy to resolve.
all_goals grind
else
-- If `i < mid`, we can apply the theorem recursively replacing
-- `as` with `qsort.sort lt as' lo mid ⋯` and `lo` with `mid + 1`.
apply qsort_sort_spec lt_asymm le_trans _ _ _ (w_as := rfl) <;> grind
· -- Just an arithmetical contradiction.
grind
/--
The slice of `as.qsort lt lo hi` from `lo` to `hi` (inclusive) is sorted.
This variant states that adjacent elements are non-decreasing.
See `qsort_sorted'` for a variant about arbitrary pairs of indices.
-/
theorem qsort_sorted₁' (lt : αα → Bool) (lt_asymm : ∀ {a b}, lt a b → ¬ lt b a)
(le_trans : ∀ {a b c}, ¬ lt b a → ¬ lt c b → ¬ lt c a)
(as : Array α) (lo hi : Nat) (i) (h₁ : lo ≤ i) (h₂ : i < hi) (h₃ : i + 1 < as.size) :
¬ lt ((as.qsort lt lo hi)[i + 1]'(by grind)) ((as.qsort lt lo hi)[i]'(by grind)) := by
unfold qsort
split <;> rename_i w
· grind
· apply qsort_sort_spec lt lt_asymm le_trans (w_as := rfl) <;> grind
/--
`Array.qsort` returns a sorted array, i.e. adjacent elements are non-decreasing.
See `qsort_sorted` for a variant about arbitrary pairs of indices.
-/
theorem qsort_sorted₁ (lt : αα → Bool) (lt_asymm : ∀ {a b}, lt a b → ¬ lt b a)
(le_trans : ∀ {a b c}, ¬ lt b a → ¬ lt c b → ¬ lt c a) (as : Array α)
(i) (h : i + 1 < (qsort as lt).size) :
¬ lt (as.qsort lt)[i + 1] (as.qsort lt)[i] := by
have := qsort_sorted₁' lt lt_asymm le_trans
grind
/-- The slice of `as.qsort lt lo hi` from `lo` to `hi` (inclusive) is sorted. -/
theorem qsort_sorted' (lt : αα → Bool) (lt_asymm : ∀ {a b}, lt a b → ¬ lt b a)
(le_trans : ∀ {a b c}, ¬ lt b a → ¬ lt c b → ¬ lt c a)
(as : Array α) (lo hi : Nat) (i j) (h₁ : lo ≤ i) (h₂ : i < j) (h₃ : j ≤ hi) (h₄ : j < as.size) :
¬ lt ((as.qsort lt lo hi)[j]'(by grind)) ((as.qsort lt lo hi)[i]'(by grind)) := by
induction j with
| zero => grind
| succ j ih =>
if p : i = j then
subst p
apply qsort_sorted₁' <;> grind
else
apply le_trans (b := (as.qsort lt lo hi)[j]'(by grind))
· grind
· apply qsort_sorted₁' <;> grind
theorem qsort_sorted (lt : αα → Bool) (lt_asymm : ∀ {a b}, lt a b → ¬ lt b a)
(le_trans : ∀ {a b c}, ¬ lt b a → ¬ lt c b → ¬ lt c a) (as : Array α) :
∀ i j, (h₁ : i < j) → (h₂ : j < (qsort as lt).size) →
¬ lt (as.qsort lt)[j] (as.qsort lt)[i] := by
have := qsort_sorted' lt lt_asymm le_trans
grind
end Array