From f4c9934171a22d376fb7c318ce2dcf80ab0eaf0e Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Thu, 16 Jan 2025 17:33:54 +1100 Subject: [PATCH] feat: Vector.getElem_flatMap (#6661) This PR adds array indexing lemmas for `Vector.flatMap`. (These were not available for `List` and `Array` due to variable lengths.) --- src/Init/Data/Nat/Lemmas.lean | 8 +++++ src/Init/Data/Vector/Lemmas.lean | 55 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/Init/Data/Nat/Lemmas.lean b/src/Init/Data/Nat/Lemmas.lean index 2fd5a8d700..6cd491d64b 100644 --- a/src/Init/Data/Nat/Lemmas.lean +++ b/src/Init/Data/Nat/Lemmas.lean @@ -622,6 +622,14 @@ protected theorem pos_of_mul_pos_right {a b : Nat} (h : 0 < a * b) : 0 < a := by 0 < a * b ↔ 0 < a := ⟨Nat.pos_of_mul_pos_right, fun w => Nat.mul_pos w h⟩ +protected theorem pos_of_lt_mul_left {a b c : Nat} (h : a < b * c) : 0 < c := by + replace h : 0 < b * c := by omega + exact Nat.pos_of_mul_pos_left h + +protected theorem pos_of_lt_mul_right {a b c : Nat} (h : a < b * c) : 0 < b := by + replace h : 0 < b * c := by omega + exact Nat.pos_of_mul_pos_right h + /-! ### div/mod -/ theorem mod_two_eq_zero_or_one (n : Nat) : n % 2 = 0 ∨ n % 2 = 1 := diff --git a/src/Init/Data/Vector/Lemmas.lean b/src/Init/Data/Vector/Lemmas.lean index 17964a63ef..91016438aa 100644 --- a/src/Init/Data/Vector/Lemmas.lean +++ b/src/Init/Data/Vector/Lemmas.lean @@ -1456,6 +1456,44 @@ theorem append_eq_map_iff {f : α → β} : mk (L.map toArray).flatten (by simp [Function.comp_def, Array.map_const', h]) := by simp [flatten] +@[simp] theorem getElem_flatten (l : Vector (Vector β m) n) (i : Nat) (hi : i < n * m) : + l.flatten[i] = + haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)] + haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi) + l[i / m][i % m] := by + rcases l with ⟨⟨l⟩, rfl⟩ + simp only [flatten_mk, List.map_toArray, getElem_mk, List.getElem_toArray, Array.flatten_toArray] + induction l generalizing i with + | nil => simp at hi + | cons a l ih => + simp only [List.map_cons, List.map_map, List.flatten_cons] + by_cases h : i < m + · rw [List.getElem_append_left (by simpa)] + have h₁ : i / m = 0 := Nat.div_eq_of_lt h + have h₂ : i % m = i := Nat.mod_eq_of_lt h + simp [h₁, h₂] + · have h₁ : a.toList.length ≤ i := by simp; omega + rw [List.getElem_append_right h₁] + simp only [Array.length_toList, size_toArray] + specialize ih (i - m) (by simp_all [Nat.add_one_mul]; omega) + have h₂ : i / m = (i - m) / m + 1 := by + conv => lhs; rw [show i = i - m + m by omega] + rw [Nat.add_div_right] + exact Nat.pos_of_lt_mul_left hi + simp only [Array.length_toList, size_toArray] at h₁ + have h₃ : (i - m) % m = i % m := (Nat.mod_eq_sub_mod h₁).symm + simp_all + +theorem getElem?_flatten (l : Vector (Vector β m) n) (i : Nat) : + l.flatten[i]? = + if hi : i < n * m then + haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)] + haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi) + some l[i / m][i % m] + else + none := by + simp [getElem?_def] + @[simp] theorem flatten_singleton (l : Vector α n) : #v[l].flatten = l.cast (by simp) := by simp [flatten] @@ -1542,6 +1580,23 @@ theorem flatMap_def (l : Vector α n) (f : α → Vector β m) : l.flatMap f = f rcases l with ⟨l, rfl⟩ simp [Array.flatMap_def, Function.comp_def] +@[simp] theorem getElem_flatMap (l : Vector α n) (f : α → Vector β m) (i : Nat) (hi : i < n * m) : + (l.flatMap f)[i] = + haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)] + haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi) + (f (l[i / m]))[i % m] := by + rw [flatMap_def, getElem_flatten, getElem_map] + +theorem getElem?_flatMap (l : Vector α n) (f : α → Vector β m) (i : Nat) : + (l.flatMap f)[i]? = + if hi : i < n * m then + haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)] + haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi) + some ((f (l[i / m]))[i % m]) + else + none := by + simp [getElem?_def] + @[simp] theorem flatMap_id (l : Vector (Vector α m) n) : l.flatMap id = l.flatten := by simp [flatMap_def] @[simp] theorem flatMap_id' (l : Vector (Vector α m) n) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]