feat: Array.swap takes Nat arguments, with tactic provided proofs (#6194)

This PR changes the signature of `Array.swap`, so it takes `Nat`
arguments with tactic provided bounds checking. It also renames
`Array.swap!` to `Array.swapIfInBounds`.
This commit is contained in:
Kim Morrison 2024-11-24 18:59:57 +11:00 committed by GitHub
parent 884a9ea2ff
commit 42e98bd3c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 64 additions and 110 deletions

View file

@ -166,15 +166,15 @@ This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_fswap"]
def swap (a : Array α) (i j : @& Fin a.size) : Array α :=
def swap (a : Array α) (i j : @& Nat) (hi : i < a.size := by get_elem_tactic) (hj : j < a.size := by get_elem_tactic) : Array α :=
let v₁ := a[i]
let v₂ := a[j]
let a' := a.set i v₂
a'.set j v₁ (Nat.lt_of_lt_of_eq j.isLt (size_set a i v₂ _).symm)
a'.set j v₁ (Nat.lt_of_lt_of_eq hj (size_set a i v₂ _).symm)
@[simp] theorem size_swap (a : Array α) (i j : Fin a.size) : (a.swap i j).size = a.size := by
@[simp] theorem size_swap (a : Array α) (i j : Nat) {hi hj} : (a.swap i j hi hj).size = a.size := by
show ((a.set i a[j]).set j a[i]
(Nat.lt_of_lt_of_eq j.isLt (size_set a i a[j] _).symm)).size = a.size
(Nat.lt_of_lt_of_eq hj (size_set a i a[j] _).symm)).size = a.size
rw [size_set, size_set]
/--
@ -184,12 +184,14 @@ This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_swap"]
def swap! (a : Array α) (i j : @& Nat) : Array α :=
def swapIfInBounds (a : Array α) (i j : @& Nat) : Array α :=
if h₁ : i < a.size then
if h₂ : j < a.size then swap a ⟨i, h₁⟩ ⟨j, h₂⟩
if h₂ : j < a.size then swap a i j
else a
else a
@[deprecated swapIfInBounds (since := "2024-11-24")] abbrev swap! := @swapIfInBounds
/-! ### GetElem instance for `USize`, backed by `uget` -/
instance : GetElem (Array α) USize α fun xs i => i.toNat < xs.size where
@ -250,7 +252,7 @@ def get? (a : Array α) (i : Nat) : Option α :=
def back? (a : Array α) : Option α :=
a[a.size - 1]?
@[inline] def swapAt (a : Array α) (i : Fin a.size) (v : α) : α × Array α :=
@[inline] def swapAt (a : Array α) (i : Nat) (v : α) (hi : i < a.size := by get_elem_tactic) : α × Array α :=
let e := a[i]
let a := a.set i v
(e, a)
@ -258,7 +260,7 @@ def back? (a : Array α) : Option α :=
@[inline]
def swapAt! (a : Array α) (i : Nat) (v : α) : α × Array α :=
if h : i < a.size then
swapAt a ⟨i, h⟩ v
swapAt a i v
else
have : Inhabited (α × Array α) := ⟨(v, a)⟩
panic! ("index " ++ toString i ++ " out of bounds")
@ -747,7 +749,7 @@ where
loop (as : Array α) (i : Nat) (j : Fin as.size) :=
if h : i < j then
have := termination h
let as := as.swap ⟨i, Nat.lt_trans h j.2⟩ j
let as := as.swap i j (Nat.lt_trans h j.2)
have : j-1 < as.size := by rw [size_swap]; exact Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
loop as (i+1) ⟨j-1, this⟩
else
@ -787,7 +789,7 @@ it has to backshift all elements at positions greater than `i`.-/
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def eraseIdx (a : Array α) (i : Nat) (h : i < a.size := by get_elem_tactic) : Array α :=
if h' : i + 1 < a.size then
let a' := a.swap ⟨i + 1, h'⟩ ⟨i, h⟩
let a' := a.swap (i + 1) i
a'.eraseIdx (i + 1) (by simp [a', h'])
else
a.pop
@ -834,7 +836,7 @@ def eraseP (as : Array α) (p : α → Bool) : Array α :=
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
loop (as : Array α) (j : Fin as.size) :=
if i < j then
let j' := ⟨j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2⟩
let j' : Fin as.size := ⟨j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2⟩
let as := as.swap j' j
loop as ⟨j', by rw [size_swap]; exact j'.2⟩
else

View file

@ -23,6 +23,6 @@ where
| j'+1 =>
have h' : j' < a.size := by subst j; exact Nat.lt_trans (Nat.lt_succ_self _) h
if lt a[j] a[j'] then
swapLoop (a.swap ⟨j, h⟩ ⟨j', h'⟩) j' (by rw [size_swap]; assumption; done)
swapLoop (a.swap j j') j' (by rw [size_swap]; assumption; done)
else
a

View file

@ -816,22 +816,22 @@ theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
private theorem fin_cast_val (e : n = n') (i : Fin n) : e ▸ i = ⟨i.1, e ▸ i.2⟩ := by cases e; rfl
theorem swap_def (a : Array α) (i j : Fin a.size) :
a.swap i j = (a.set i a[j]).set j a[i] := by
theorem swap_def (a : Array α) (i j : Nat) (hi hj) :
a.swap i j hi hj = (a.set i a[j]).set j a[i] (by simpa using hj) := by
simp [swap, fin_cast_val]
@[simp] theorem toList_swap (a : Array α) (i j : Fin a.size) :
(a.swap i j).toList = (a.toList.set i a[j]).set j a[i] := by simp [swap_def]
@[simp] theorem toList_swap (a : Array α) (i j : Nat) (hi hj) :
(a.swap i j hi hj).toList = (a.toList.set i a[j]).set j a[i] := by simp [swap_def]
theorem getElem?_swap (a : Array α) (i j : Fin a.size) (k : Nat) : (a.swap i j)[k]? =
if j = k then some a[i.1] else if i = k then some a[j.1] else a[k]? := by
theorem getElem?_swap (a : Array α) (i j : Nat) (hi hj) (k : Nat) : (a.swap i j hi hj)[k]? =
if j = k then some a[i] else if i = k then some a[j] else a[k]? := by
simp [swap_def, get?_set, ← getElem_fin_eq_getElem_toList]
@[simp] theorem swapAt_def (a : Array α) (i : Fin a.size) (v : α) :
a.swapAt i v = (a[i.1], a.set i v) := rfl
@[simp] theorem swapAt_def (a : Array α) (i : Nat) (v : α) (hi) :
a.swapAt i v hi = (a[i], a.set i v) := rfl
@[simp] theorem size_swapAt (a : Array α) (i : Fin a.size) (v : α) :
(a.swapAt i v).2.size = a.size := by simp [swapAt_def]
@[simp] theorem size_swapAt (a : Array α) (i : Nat) (v : α) (hi) :
(a.swapAt i v hi).2.size = a.size := by simp [swapAt_def]
@[simp]
theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
@ -878,8 +878,10 @@ theorem eq_push_of_size_ne_zero {as : Array α} (h : as.size ≠ 0) :
theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rfl
@[simp] theorem size_swap! (a : Array α) (i j) :
(a.swap! i j).size = a.size := by unfold swap!; split <;> (try split) <;> simp [size_swap]
@[simp] theorem size_swapIfInBounds (a : Array α) (i j) :
(a.swapIfInBounds i j).size = a.size := by unfold swapIfInBounds; split <;> (try split) <;> simp [size_swap]
@[deprecated size_swapIfInBounds (since := "2024-11-24")] abbrev size_swap! := @size_swapIfInBounds
@[simp] theorem size_reverse (a : Array α) : a.reverse.size = a.size := by
let rec go (as : Array α) (i j) : (reverse.loop as i j).size = as.size := by
@ -1641,28 +1643,30 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) :=
open Fin
@[simp] theorem getElem_swap_right (a : Array α) {i j : Fin a.size} : (a.swap i j)[j.1] = a[i] := by
@[simp] theorem getElem_swap_right (a : Array α) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by
simp [swap_def, getElem_set]
@[simp] theorem getElem_swap_left (a : Array α) {i j : Fin a.size} : (a.swap i j)[i.1] = a[j] := by
@[simp] theorem getElem_swap_left (a : Array α) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[i]'(by simpa using hi) = a[j] := by
simp +contextual [swap_def, getElem_set]
@[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Fin a.size} (hp : p < a.size)
(hi : p ≠ i) (hj : p ≠ j) : (a.swap i j)[p]'(a.size_swap .. |>.symm ▸ hp) = a[p] := by
simp [swap_def, getElem_set, hi.symm, hj.symm]
@[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Nat} {hi hj} (hp : p < a.size)
(hi' : p ≠ i) (hj' : p ≠ j) : (a.swap i j hi hj)[p]'(a.size_swap .. |>.symm ▸ hp) = a[p] := by
simp [swap_def, getElem_set, hi'.symm, hj'.symm]
theorem getElem_swap' (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < a.size) :
(a.swap i j)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
theorem getElem_swap' (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < a.size) :
(a.swap i j hi hj)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
split
· simp_all only [getElem_swap_left]
· split <;> simp_all
theorem getElem_swap (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < (a.swap i j).size) :
(a.swap i j)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.swap i j).size) :
(a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
apply getElem_swap'
@[simp] theorem swap_swap (a : Array α) {i j : Fin a.size} :
(a.swap i j).swap ⟨i.1, (a.size_swap ..).symm ▸ i.2⟩ ⟨j.1, (a.size_swap ..).symm ▸ j.2⟩ = a := by
@[simp] theorem swap_swap (a : Array α) {i j : Nat} (hi hj) :
(a.swap i j hi hj).swap i j ((a.size_swap ..).symm ▸ hi) ((a.size_swap ..).symm ▸ hj) = a := by
apply ext
· simp only [size_swap]
· intros
@ -1671,7 +1675,7 @@ theorem getElem_swap (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < (a.sw
· simp_all
· split <;> simp_all
theorem swap_comm (a : Array α) {i j : Fin a.size} : a.swap i j = a.swap j i := by
theorem swap_comm (a : Array α) {i j : Nat} {hi hj} : a.swap i j hi hj = a.swap j i hj hi := by
apply ext
· simp only [size_swap]
· intros
@ -1834,8 +1838,8 @@ theorem all_toArray (p : α → Bool) (l : List α) : l.toArray.all p = l.all p
subst h
rw [all_toList]
@[simp] theorem swap_toArray (l : List α) (i j : Fin l.toArray.size) :
l.toArray.swap i j = ((l.set i l[j]).set j l[i]).toArray := by
@[simp] theorem swap_toArray (l : List α) (i j : Nat) {hi hj}:
l.toArray.swap i j hi hj = ((l.set i l[j]).set j l[i]).toArray := by
apply ext'
simp

View file

@ -13,19 +13,19 @@ namespace Array
def qpartition (as : Array α) (lt : αα → Bool) (lo hi : Nat) : Nat × Array α :=
if h : as.size = 0 then (0, as) else have : Inhabited α := ⟨as[0]'(by revert h; cases as.size <;> simp)⟩ -- TODO: remove
let mid := (lo + hi) / 2
let as := if lt (as.get! mid) (as.get! lo) then as.swap! lo mid else as
let as := if lt (as.get! hi) (as.get! lo) then as.swap! lo hi else as
let as := if lt (as.get! mid) (as.get! hi) then as.swap! mid hi else as
let as := if lt (as.get! mid) (as.get! lo) then as.swapIfInBounds lo mid else as
let as := if lt (as.get! hi) (as.get! lo) then as.swapIfInBounds lo hi else as
let as := if lt (as.get! mid) (as.get! hi) then as.swapIfInBounds mid hi else as
let pivot := as.get! hi
let rec loop (as : Array α) (i j : Nat) :=
if h : j < hi then
if lt (as.get! j) pivot then
let as := as.swap! i j
let as := as.swapIfInBounds i j
loop as (i+1) (j+1)
else
loop as i (j+1)
else
let as := as.swap! i hi
let as := as.swapIfInBounds i hi
(i, as)
termination_by hi - j
decreasing_by all_goals simp_wf; decreasing_trivial_pre_omega

View file

@ -23,16 +23,13 @@ def split (s : Subarray α) (i : Fin s.size.succ) : (Subarray α × Subarray α)
let ⟨i', isLt⟩ := i
have := s.start_le_stop
have := s.stop_le_array_size
have : i' ≤ s.stop - s.start := Nat.lt_succ.mp isLt
have : s.start + i' ≤ s.stop := by omega
have : s.start + i' ≤ s.array.size := by omega
have : s.start + i' ≤ s.stop := by
simp only [size] at isLt
omega
let pre := {s with
stop := s.start + i',
start_le_stop := by omega,
stop_le_array_size := by assumption
stop_le_array_size := by omega
}
let post := {s with
start := s.start + i'
@ -48,9 +45,7 @@ def drop (arr : Subarray α) (i : Nat) : Subarray α where
array := arr.array
start := min (arr.start + i) arr.stop
stop := arr.stop
start_le_stop := by
rw [Nat.min_def]
split <;> simp only [Nat.le_refl, *]
start_le_stop := by omega
stop_le_array_size := arr.stop_le_array_size
/--
@ -63,9 +58,7 @@ def take (arr : Subarray α) (i : Nat) : Subarray α where
stop := min (arr.start + i) arr.stop
start_le_stop := by
have := arr.start_le_stop
rw [Nat.min_def]
split <;> omega
omega
stop_le_array_size := by
have := arr.stop_le_array_size
rw [Nat.min_def]
split <;> omega
omega

View file

@ -24,20 +24,20 @@ macro:max "↑" x:term:max : term => `(UInt32.toNat $x)
| as, i, j =>
if j < hi then
if lt (as.get! ↑j) pivot then
let as := as.swap! ↑i ↑j;
let as := as.swapIfInBounds ↑i ↑j;
partitionAux lt hi pivot as (i+1) (j+1)
else
partitionAux lt hi pivot as i (j+1)
else
let as := as.swap! ↑i ↑hi;
let as := as.swapIfInBounds ↑i ↑hi;
(i, as)
set_option pp.all true
@[inline] def partition {α : Type} [Inhabited α] (as : Array α) (lt : αα → Bool) (lo hi : Idx) : Idx × Array α :=
let mid : Idx := (lo + hi) / 2;
let as := if lt (as.get! ↑mid) (as.get! ↑lo) then as.swap! ↑lo ↑mid else as;
let as := if lt (as.get! ↑hi) (as.get! ↑lo) then as.swap! ↑lo ↑hi else as;
let as := if lt (as.get! ↑mid) (as.get! ↑hi) then as.swap! ↑mid ↑hi else as;
let as := if lt (as.get! ↑mid) (as.get! ↑lo) then as.swapIfInBounds ↑lo ↑mid else as;
let as := if lt (as.get! ↑hi) (as.get! ↑lo) then as.swapIfInBounds ↑lo ↑hi else as;
let as := if lt (as.get! ↑mid) (as.get! ↑hi) then as.swapIfInBounds ↑mid ↑hi else as;
let pivot := as.get! ↑hi;
partitionAux lt hi pivot as lo lo

View file

@ -6,7 +6,7 @@ def myfun (x : Array α) (i : Fin x.size) : Array α :=
let next := 2*i.1 + 1
if h : next < x.size then
have : x.size - next < x.size - i.1 := sorry
myfun (x.swap i next,h⟩) ⟨next, (x.size_swap _ _).symm ▸ h⟩
myfun (x.swap i next) ⟨next, (x.size_swap _ _).symm ▸ h⟩
else
x
termination_by x.size - i.1

View file

@ -2,8 +2,7 @@ def Array.swaps (a : Array α) : List (Fin a.size × Fin a.size) → Array α
| [] => a
| (i, j) :: ijs =>
have : (a.swap i j).size = a.size := a.size_swap _ _
swaps (a.swap i j) (ijs.map (fun p => ⟨⟨p.1.1, this.symm ▸ p.1.2⟩, ⟨p.2.1, this.symm ▸ p.2.2⟩⟩))
swaps (a.swap i j) (ijs.map (fun p => ⟨⟨p.1.1, by simp⟩, ⟨p.2.1, by simp⟩⟩))
termination_by l => l.length
set_option maxHeartbeats 1000 in

View file

@ -11,7 +11,7 @@ info: Array.insertionSort.swapLoop.eq_1.{u_1} {α : Type u_1} (lt : αα
info: Array.insertionSort.swapLoop.eq_2.{u_1} {α : Type u_1} (lt : αα → Bool) (a : Array α) (j' : Nat)
(h : j'.succ < a.size) :
Array.insertionSort.swapLoop lt a j'.succ h =
if lt a[j'.succ] a[j'] = true then Array.insertionSort.swapLoop lt (a.swap ⟨j'.succ, h⟩ ⟨j', ⋯⟩) j' ⋯ else a
if lt a[j'.succ] a[j'] = true then Array.insertionSort.swapLoop lt (a.swap j'.succ j' h ⋯) j' ⋯ else a
-/
#guard_msgs in
#check Array.insertionSort.swapLoop.eq_2

View file

@ -103,10 +103,10 @@ def popMaxAux {lt} (self : BinaryHeap α lt) : {a' : BinaryHeap α lt // a'.size
match e: self.1.size with
| 0 => ⟨self, by simp [size, e]⟩
| n+1 =>
have h0 := by rw [e]; apply Nat.succ_pos
have hn := by rw [e]; apply Nat.lt_succ_self
have h0 : 0 < self.1.size := by rw [e]; apply Nat.succ_pos
have hn : n < self.1.size := by rw [e]; apply Nat.lt_succ_self
if hn0 : 0 < n then
let a := self.1.swap ⟨0, h0⟩ ⟨n, hn⟩ |>.pop
let a := self.1.swap 0 n |>.pop
⟨⟨heapifyDown lt a ⟨0, sorry⟩⟩,
by simp [size, a]⟩
else

View file

@ -1,39 +0,0 @@
namespace Ex1
variable (a : Nat) (i : Fin a) (h : 1 = a)
example : i < a := h ▸ i.2 -- `▸` uses `subst` here
end Ex1
namespace Ex2
def heapifyDown' (a : Array α) (i : Fin a.size) : Array α := sorry
def heapifyDown (a : Array α) (i : Fin a.size) : Array α :=
heapifyDown' a ⟨i.1, a.size_swap i i ▸ i.2⟩ -- Error, failed to compute motive, `subst` is not applicable here
end Ex2
namespace Ex3
def heapifyDown (a : Array α) (i : Fin a.size) : Array α :=
have : i < i := sorry
heapifyDown a ⟨i.1, a.size_swap i i ▸ i.2⟩ -- Error, failed to compute motive, `subst` is not applicable here
termination_by i.1
decreasing_by assumption
end Ex3
namespace Ex4
def heapifyDown (lt : αα → Bool) (a : Array α) (i : Fin a.size) : Array α :=
let left := 2 * i.1 + 1
let right := left + 1
have left_le : i ≤ left := sorry
have right_le : i ≤ right := sorry
have i_le : i ≤ i := Nat.le_refl _
have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then
if lt a[i] a[left] then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩
have j := if h : right < a.size then
if lt a[j.1.1] a[right] then ⟨⟨right, h⟩, right_le⟩ else j else j
if h : i ≠ j then
let a' := a.swap i j
have : a'.size - j < a.size - i := sorry
heapifyDown lt a' ⟨j.1.1, a.size_swap i j ▸ j.1.2⟩ -- Error, failed to compute motive, `subst` is not applicable here
else
a
termination_by a.size - i.1
decreasing_by assumption
end Ex4

View file

@ -1,5 +0,0 @@
substBadMotive.lean:7:4-7:16: warning: declaration uses 'sorry'
substBadMotive.lean:9:23-9:44: error: invalid `▸` notation, failed to compute motive for the substitution
substBadMotive.lean:15:22-15:43: error: invalid `▸` notation, failed to compute motive for the substitution
substBadMotive.lean:13:0-17:24: error: well-founded recursion cannot be used, 'Ex3.heapifyDown' does not take any (non-fixed) arguments
substBadMotive.lean:34:30-34:53: error: invalid `▸` notation, failed to compute motive for the substitution