feat: Vector.getElem_flatMap (#6661)
This PR adds array indexing lemmas for `Vector.flatMap`. (These were not available for `List` and `Array` due to variable lengths.)
This commit is contained in:
parent
80ddbf45eb
commit
f4c9934171
2 changed files with 63 additions and 0 deletions
|
|
@ -622,6 +622,14 @@ protected theorem pos_of_mul_pos_right {a b : Nat} (h : 0 < a * b) : 0 < a := by
|
|||
0 < a * b ↔ 0 < a :=
|
||||
⟨Nat.pos_of_mul_pos_right, fun w => Nat.mul_pos w h⟩
|
||||
|
||||
protected theorem pos_of_lt_mul_left {a b c : Nat} (h : a < b * c) : 0 < c := by
|
||||
replace h : 0 < b * c := by omega
|
||||
exact Nat.pos_of_mul_pos_left h
|
||||
|
||||
protected theorem pos_of_lt_mul_right {a b c : Nat} (h : a < b * c) : 0 < b := by
|
||||
replace h : 0 < b * c := by omega
|
||||
exact Nat.pos_of_mul_pos_right h
|
||||
|
||||
/-! ### div/mod -/
|
||||
|
||||
theorem mod_two_eq_zero_or_one (n : Nat) : n % 2 = 0 ∨ n % 2 = 1 :=
|
||||
|
|
|
|||
|
|
@ -1456,6 +1456,44 @@ theorem append_eq_map_iff {f : α → β} :
|
|||
mk (L.map toArray).flatten (by simp [Function.comp_def, Array.map_const', h]) := by
|
||||
simp [flatten]
|
||||
|
||||
@[simp] theorem getElem_flatten (l : Vector (Vector β m) n) (i : Nat) (hi : i < n * m) :
|
||||
l.flatten[i] =
|
||||
haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)]
|
||||
haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi)
|
||||
l[i / m][i % m] := by
|
||||
rcases l with ⟨⟨l⟩, rfl⟩
|
||||
simp only [flatten_mk, List.map_toArray, getElem_mk, List.getElem_toArray, Array.flatten_toArray]
|
||||
induction l generalizing i with
|
||||
| nil => simp at hi
|
||||
| cons a l ih =>
|
||||
simp only [List.map_cons, List.map_map, List.flatten_cons]
|
||||
by_cases h : i < m
|
||||
· rw [List.getElem_append_left (by simpa)]
|
||||
have h₁ : i / m = 0 := Nat.div_eq_of_lt h
|
||||
have h₂ : i % m = i := Nat.mod_eq_of_lt h
|
||||
simp [h₁, h₂]
|
||||
· have h₁ : a.toList.length ≤ i := by simp; omega
|
||||
rw [List.getElem_append_right h₁]
|
||||
simp only [Array.length_toList, size_toArray]
|
||||
specialize ih (i - m) (by simp_all [Nat.add_one_mul]; omega)
|
||||
have h₂ : i / m = (i - m) / m + 1 := by
|
||||
conv => lhs; rw [show i = i - m + m by omega]
|
||||
rw [Nat.add_div_right]
|
||||
exact Nat.pos_of_lt_mul_left hi
|
||||
simp only [Array.length_toList, size_toArray] at h₁
|
||||
have h₃ : (i - m) % m = i % m := (Nat.mod_eq_sub_mod h₁).symm
|
||||
simp_all
|
||||
|
||||
theorem getElem?_flatten (l : Vector (Vector β m) n) (i : Nat) :
|
||||
l.flatten[i]? =
|
||||
if hi : i < n * m then
|
||||
haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)]
|
||||
haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi)
|
||||
some l[i / m][i % m]
|
||||
else
|
||||
none := by
|
||||
simp [getElem?_def]
|
||||
|
||||
@[simp] theorem flatten_singleton (l : Vector α n) : #v[l].flatten = l.cast (by simp) := by
|
||||
simp [flatten]
|
||||
|
||||
|
|
@ -1542,6 +1580,23 @@ theorem flatMap_def (l : Vector α n) (f : α → Vector β m) : l.flatMap f = f
|
|||
rcases l with ⟨l, rfl⟩
|
||||
simp [Array.flatMap_def, Function.comp_def]
|
||||
|
||||
@[simp] theorem getElem_flatMap (l : Vector α n) (f : α → Vector β m) (i : Nat) (hi : i < n * m) :
|
||||
(l.flatMap f)[i] =
|
||||
haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)]
|
||||
haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi)
|
||||
(f (l[i / m]))[i % m] := by
|
||||
rw [flatMap_def, getElem_flatten, getElem_map]
|
||||
|
||||
theorem getElem?_flatMap (l : Vector α n) (f : α → Vector β m) (i : Nat) :
|
||||
(l.flatMap f)[i]? =
|
||||
if hi : i < n * m then
|
||||
haveI : i / m < n := by rwa [Nat.div_lt_iff_lt_mul (Nat.pos_of_lt_mul_left hi)]
|
||||
haveI : i % m < m := Nat.mod_lt _ (Nat.pos_of_lt_mul_left hi)
|
||||
some ((f (l[i / m]))[i % m])
|
||||
else
|
||||
none := by
|
||||
simp [getElem?_def]
|
||||
|
||||
@[simp] theorem flatMap_id (l : Vector (Vector α m) n) : l.flatMap id = l.flatten := by simp [flatMap_def]
|
||||
|
||||
@[simp] theorem flatMap_id' (l : Vector (Vector α m) n) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue