From 3bf95e9b585ceca4e88baeda8b92a44f485cffdf Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Tue, 20 May 2025 15:28:29 +1000 Subject: [PATCH] feat: add List/Array/Vector.ofFnM (#8389) This PR adds the `List/Array/Vector.ofFnM`, the monadic analogues of `ofFn`, along with basic theory. At the same time we pave some potholes in nearby API. --------- Co-authored-by: Eric Wieser --- src/Init/Data/Array/Basic.lean | 6 ++ src/Init/Data/Array/Lemmas.lean | 53 +++++++++++-- src/Init/Data/Array/Monadic.lean | 10 +++ src/Init/Data/Array/OfFn.lean | 82 ++++++++++++++++++++ src/Init/Data/Fin/Fold.lean | 125 +++++++++++++++++++++++++++--- src/Init/Data/Fin/Lemmas.lean | 14 ++++ src/Init/Data/List/FinRange.lean | 50 +++++++++++- src/Init/Data/List/Lemmas.lean | 5 ++ src/Init/Data/List/Monadic.lean | 12 ++- src/Init/Data/List/OfFn.lean | 93 +++++++++++++++++++++- src/Init/Data/List/ToArray.lean | 6 -- src/Init/Data/Nat/Fold.lean | 20 +++-- src/Init/Data/Vector/Basic.lean | 2 + src/Init/Data/Vector/Lemmas.lean | 10 +-- src/Init/Data/Vector/Monadic.lean | 5 ++ src/Init/Data/Vector/OfFn.lean | 119 ++++++++++++++++++++++++++++ 16 files changed, 571 insertions(+), 41 deletions(-) diff --git a/src/Init/Data/Array/Basic.lean b/src/Init/Data/Array/Basic.lean index 1f0b7d5dc1..38a721356f 100644 --- a/src/Init/Data/Array/Basic.lean +++ b/src/Init/Data/Array/Basic.lean @@ -112,6 +112,10 @@ theorem mem_def {a : α} {as : Array α} : a ∈ as ↔ a ∈ as.toList := rw [Array.mem_def, ← getElem_toList] apply List.getElem_mem +@[simp] theorem emptyWithCapacity_eq {α n} : @emptyWithCapacity α n = #[] := rfl + +@[simp] theorem mkEmpty_eq {α n} : @mkEmpty α n = #[] := rfl + end Array namespace List @@ -334,6 +338,8 @@ def ofFn {n} (f : Fin n → α) : Array α := go 0 (emptyWithCapacity n) where if h : i < n then go (i+1) (acc.push (f ⟨i, h⟩)) else acc decreasing_by simp_wf; decreasing_trivial_pre_omega +-- See also `Array.ofFnM` defined in `Init.Data.Array.OfFn`. + /-- Constructs an array that contains all the numbers from `0` to `n`, exclusive. diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 130a637358..875ef9c05e 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -61,11 +61,6 @@ theorem toArray_eq : List.toArray as = xs ↔ as = xs.toList := by @[grind] theorem size_empty : (#[] : Array α).size = 0 := rfl -@[simp] theorem emptyWithCapacity_eq {α n} : @emptyWithCapacity α n = #[] := rfl - -@[deprecated emptyWithCapacity_eq (since := "2025-03-12")] -theorem mkEmpty_eq {α n} : @mkEmpty α n = #[] := rfl - /-! ### size -/ @[grind →] theorem eq_empty_of_size_eq_zero (h : xs.size = 0) : xs = #[] := by @@ -247,6 +242,12 @@ theorem back?_pop {xs : Array α} : /-! ### push -/ +@[simp] theorem push_empty : #[].push x = #[x] := rfl + +@[simp] theorem toList_push {xs : Array α} {x : α} : (xs.push x).toList = xs.toList ++ [x] := by + rcases xs with ⟨xs⟩ + simp + @[simp] theorem push_ne_empty {a : α} {xs : Array α} : xs.push a ≠ #[] := by cases xs simp @@ -3266,6 +3267,22 @@ rather than `(arr.push a).size` as the argument. (xs.push a).foldrM f init start = f a init >>= xs.foldrM f := by simp [← foldrM_push, h] +@[simp, grind] theorem _root_.List.foldrM_push_eq_append [Monad m] [LawfulMonad m] {l : List α} {f : α → m β} {xs : Array β} : + l.foldrM (fun x xs => xs.push <$> f x) xs = do return xs ++ (← l.reverse.mapM f).toArray := by + induction l with + | nil => simp + | cons a l ih => + simp [ih] + congr 1 + funext l' + congr 1 + funext x + simp + +@[simp, grind] theorem _root_.List.foldlM_push_eq_append [Monad m] [LawfulMonad m] {l : List α} {f : α → m β} {xs : Array β} : + l.foldlM (fun xs x => xs.push <$> f x) xs = do return xs ++ (← l.mapM f).toArray := by + induction l generalizing xs <;> simp [*] + /-! ### foldl / foldr -/ @[grind] theorem foldl_empty {f : β → α → β} {init : β} : (#[].foldl f init) = init := rfl @@ -3362,6 +3379,32 @@ rather than `(arr.push a).size` as the argument. rcases as with ⟨as⟩ simp +@[simp, grind] theorem _root_.List.foldr_push_eq_append {l : List α} {f : α → β} {xs : Array β} : + l.foldr (fun x xs => xs.push (f x)) xs = xs ++ (l.reverse.map f).toArray := by + induction l <;> simp [*] + +/-- Variant of `List.foldr_push_eq_append` specialized to `f = id`. -/ +@[simp, grind] theorem _root_.List.foldr_push_eq_append' {l : List α} {xs : Array α} : + l.foldr (fun x xs => xs.push x) xs = xs ++ l.reverse.toArray := by + induction l <;> simp [*] + +@[simp, grind] theorem _root_.List.foldl_push_eq_append {l : List α} {f : α → β} {xs : Array β} : + l.foldl (fun xs x => xs.push (f x)) xs = xs ++ (l.map f).toArray := by + induction l generalizing xs <;> simp [*] + +/-- Variant of `List.foldl_push_eq_append` specialized to `f = id`. -/ +@[simp, grind] theorem _root_.List.foldl_push_eq_append' {l : List α} {xs : Array α} : + l.foldl (fun xs x => xs.push x) xs = xs ++ l.toArray := by + simpa using List.foldl_push_eq_append (f := id) + +@[deprecated _root_.List.foldl_push_eq_append' (since := "2025-05-18")] +theorem _root_.List.foldl_push {l : List α} {as : Array α} : l.foldl Array.push as = as ++ l.toArray := by + induction l generalizing as <;> simp [*] + +@[deprecated _root_.List.foldr_push_eq_append' (since := "2025-05-18")] +theorem _root_.List.foldr_push {l : List α} {as : Array α} : l.foldr (fun a bs => push bs a) as = as ++ l.reverse.toArray := by + rw [List.foldr_eq_foldl_reverse, List.foldl_push_eq_append'] + @[simp, grind] theorem foldr_append_eq_append {xs : Array α} {f : α → Array β} {ys : Array β} : xs.foldr (f · ++ ·) ys = (xs.map f).flatten ++ ys := by rcases xs with ⟨xs⟩ diff --git a/src/Init/Data/Array/Monadic.lean b/src/Init/Data/Array/Monadic.lean index 33acd55e1d..345ce1a64c 100644 --- a/src/Init/Data/Array/Monadic.lean +++ b/src/Init/Data/Array/Monadic.lean @@ -25,6 +25,11 @@ open Nat /-! ## Monadic operations -/ +@[simp] theorem map_toList_inj [Monad m] [LawfulMonad m] + {xs : m (Array α)} {ys : m (Array α)} : + toList <$> xs = toList <$> ys ↔ xs = ys := + _root_.map_inj_right (by simp) + /-! ### mapM -/ @[simp] theorem mapM_pure [Monad m] [LawfulMonad m] {xs : Array α} {f : α → β} : @@ -34,6 +39,11 @@ open Nat @[simp] theorem mapM_id {xs : Array α} {f : α → Id β} : xs.mapM f = xs.map f := mapM_pure +@[simp] theorem mapM_map [Monad m] [LawfulMonad m] {f : α → β} {g : β → m γ} {xs : Array α} : + (xs.map f).mapM g = xs.mapM (g ∘ f) := by + rcases xs with ⟨xs⟩ + simp + @[simp] theorem mapM_append [Monad m] [LawfulMonad m] {f : α → m β} {xs ys : Array α} : (xs ++ ys).mapM f = (return (← xs.mapM f) ++ (← ys.mapM f)) := by rcases xs with ⟨xs⟩ diff --git a/src/Init/Data/Array/OfFn.lean b/src/Init/Data/Array/OfFn.lean index c875d63b44..26b340f20c 100644 --- a/src/Init/Data/Array/OfFn.lean +++ b/src/Init/Data/Array/OfFn.lean @@ -8,7 +8,9 @@ module prelude import all Init.Data.Array.Basic import Init.Data.Array.Lemmas +import Init.Data.Array.Monadic import Init.Data.List.OfFn +import Init.Data.List.FinRange /-! # Theorems about `Array.ofFn` @@ -19,6 +21,8 @@ set_option linter.indexVariables true -- Enforce naming conventions for index va namespace Array +/-! ### ofFn -/ + @[simp] theorem ofFn_zero {f : Fin 0 → α} : ofFn f = #[] := by simp [ofFn, ofFn.go] @@ -32,12 +36,23 @@ theorem ofFn_succ {f : Fin (n+1) → α} : intro h₃ simp only [show i = n by omega] +theorem ofFn_add {n m} {f : Fin (n + m) → α} : + ofFn f = (ofFn (fun i => f (i.castLE (Nat.le_add_right n m)))) ++ (ofFn (fun i => f (i.natAdd n))) := by + induction m with + | zero => simp + | succ m ih => simp [ofFn_succ, ih] + @[simp] theorem _root_.List.toArray_ofFn {f : Fin n → α} : (List.ofFn f).toArray = Array.ofFn f := by ext <;> simp @[simp] theorem toList_ofFn {f : Fin n → α} : (Array.ofFn f).toList = List.ofFn f := by apply List.ext_getElem <;> simp +theorem ofFn_succ' {f : Fin (n+1) → α} : + ofFn f = #[f 0] ++ ofFn (fun i => f i.succ) := by + apply Array.toList_inj.mp + simp [List.ofFn_succ] + @[simp] theorem ofFn_eq_empty_iff {f : Fin n → α} : ofFn f = #[] ↔ n = 0 := by rw [← Array.toList_inj] @@ -52,4 +67,71 @@ theorem mem_ofFn {n} {f : Fin n → α} {a : α} : a ∈ ofFn f ↔ ∃ i, f i = · rintro ⟨i, rfl⟩ apply mem_of_getElem (i := i) <;> simp +/-! ### ofFnM -/ + +/-- Construct (in a monadic context) an array by applying a monadic function to each index. -/ +def ofFnM {n} [Monad m] (f : Fin n → m α) : m (Array α) := + Fin.foldlM n (fun xs i => xs.push <$> f i) (Array.emptyWithCapacity n) + +@[simp] +theorem ofFnM_zero [Monad m] {f : Fin 0 → m α} : ofFnM f = pure #[] := by + simp [ofFnM] + +theorem ofFnM_succ' {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} : + ofFnM f = (do + let a ← f 0 + let as ← ofFnM fun i => f i.succ + pure (#[a] ++ as)) := by + simp [ofFnM, Fin.foldlM_eq_finRange_foldlM, List.foldlM_push_eq_append, List.finRange_succ, Function.comp_def] + +theorem ofFnM_succ {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} : + ofFnM f = (do + let as ← ofFnM fun i => f i.castSucc + let a ← f (Fin.last n) + pure (as.push a)) := by + simp [ofFnM, Fin.foldlM_succ_last] + +theorem ofFnM_add {n m} [Monad m] [LawfulMonad m] {f : Fin (n + k) → m α} : + ofFnM f = (do + let as ← ofFnM fun i : Fin n => f (i.castLE (Nat.le_add_right n k)) + let bs ← ofFnM fun i : Fin k => f (i.natAdd n) + pure (as ++ bs)) := by + induction k with + | zero => simp + | succ k ih => + simp only [ofFnM_succ, Nat.add_eq, ih, Fin.castSucc_castLE, Fin.castSucc_natAdd, bind_pure_comp, + bind_assoc, bind_map_left, Fin.natAdd_last, map_bind, Functor.map_map] + congr 1 + funext xs + congr 1 + funext ys + congr 1 + funext x + simp + +@[simp] theorem toList_ofFnM [Monad m] [LawfulMonad m] {f : Fin n → m α} : + toList <$> ofFnM f = List.ofFnM f := by + induction n with + | zero => simp + | succ n ih => simp [ofFnM_succ, List.ofFnM_succ_last, ← ih] + +@[simp] +theorem ofFnM_pure_comp [Monad m] [LawfulMonad m] {n} {f : Fin n → α} : + ofFnM (pure ∘ f) = (pure (ofFn f) : m (Array α)) := by + apply Array.map_toList_inj.mp + simp + +-- Variant of `ofFnM_pure_comp` using a lambda. +-- This is not marked a `@[simp]` as it would match on every occurrence of `ofFnM`. +theorem ofFnM_pure [Monad m] [LawfulMonad m] {n} {f : Fin n → α} : + ofFnM (fun i => pure (f i)) = (pure (ofFn f) : m (Array α)) := + ofFnM_pure_comp + +@[simp, grind =] theorem idRun_ofFnM {f : Fin n → Id α} : + Id.run (ofFnM f) = ofFn (fun i => Id.run (f i)) := by + unfold Id.run + induction n with + | zero => simp + | succ n ih => simp [ofFnM_succ', ofFn_succ', ih] + end Array diff --git a/src/Init/Data/Fin/Fold.lean b/src/Init/Data/Fin/Fold.lean index 4d4ccd83d2..1e6d3b4a11 100644 --- a/src/Init/Data/Fin/Fold.lean +++ b/src/Init/Data/Fin/Fold.lean @@ -100,6 +100,11 @@ Fin.foldrM n f xₙ = do /-! ### foldlM -/ +@[congr] theorem foldlM_congr [Monad m] {n k : Nat} (w : n = k) (f : α → Fin n → m α) : + foldlM n f = foldlM k (fun x i => f x (i.cast w.symm)) := by + subst w + rfl + theorem foldlM_loop_lt [Monad m] (f : α → Fin n → m α) (x) (h : i < n) : foldlM.loop n f x i = f x ⟨i, h⟩ >>= (foldlM.loop n f . (i+1)) := by rw [foldlM.loop, dif_pos h] @@ -120,14 +125,49 @@ theorem foldlM_loop [Monad m] (f : α → Fin (n+1) → m α) (x) (h : i < n+1) rw [foldlM_loop_eq, foldlM_loop_eq] termination_by n - i -@[simp] theorem foldlM_zero [Monad m] (f : α → Fin 0 → m α) (x) : foldlM 0 f x = pure x := - foldlM_loop_eq .. +@[simp] theorem foldlM_zero [Monad m] (f : α → Fin 0 → m α) : foldlM 0 f = pure := by + funext x + exact foldlM_loop_eq .. -theorem foldlM_succ [Monad m] (f : α → Fin (n+1) → m α) (x) : - foldlM (n+1) f x = f x 0 >>= foldlM n (fun x j => f x j.succ) := foldlM_loop .. +theorem foldlM_succ [Monad m] (f : α → Fin (n+1) → m α) : + foldlM (n+1) f = fun x => f x 0 >>= foldlM n (fun x j => f x j.succ) := by + funext x + exact foldlM_loop .. + +/-- Variant of `foldlM_succ` that splits off `Fin.last n` rather than `0`. -/ +theorem foldlM_succ_last [Monad m] [LawfulMonad m] (f : α → Fin (n+1) → m α) : + foldlM (n+1) f = fun x => foldlM n (fun x j => f x j.castSucc) x >>= (f · (Fin.last n)) := by + funext x + induction n generalizing x with + | zero => + simp [foldlM_succ] + | succ n ih => + rw [foldlM_succ] + conv => rhs; rw [foldlM_succ] + simp only [castSucc_zero, castSucc_succ, bind_assoc] + congr 1 + funext x + rw [ih] + simp + +theorem foldlM_add [Monad m] [LawfulMonad m] (f : α → Fin (n + k) → m α) : + foldlM (n + k) f = + fun x => foldlM n (fun x i => f x (i.castLE (Nat.le_add_right n k))) x >>= foldlM k (fun x i => f x (i.natAdd n)) := by + induction k with + | zero => + funext x + simp + | succ k ih => + funext x + simp [foldlM_succ_last, ← Nat.add_assoc, ih] /-! ### foldrM -/ +@[congr] theorem foldrM_congr [Monad m] {n k : Nat} (w : n = k) (f : Fin n → α → m α) : + foldrM n f = foldrM k (fun i => f (i.cast w.symm)) := by + subst w + rfl + theorem foldrM_loop_zero [Monad m] (f : Fin n → α → m α) (x) : foldrM.loop n f ⟨0, Nat.zero_le _⟩ x = pure x := by rw [foldrM.loop] @@ -145,19 +185,47 @@ theorem foldrM_loop [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x conv => rhs; rw [←bind_pure (f 0 x)] congr funext - try simp only [foldrM.loop] -- the try makes this proof work with and without opaque wf rec + simp [foldrM_loop_zero] | succ i ih => rw [foldrM_loop_succ, foldrM_loop_succ, bind_assoc] congr; funext; exact ih .. -@[simp] theorem foldrM_zero [Monad m] (f : Fin 0 → α → m α) (x) : foldrM 0 f x = pure x := - foldrM_loop_zero .. +@[simp] theorem foldrM_zero [Monad m] (f : Fin 0 → α → m α) : foldrM 0 f = pure := by + funext x + exact foldrM_loop_zero .. -theorem foldrM_succ [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) : - foldrM (n+1) f x = foldrM n (fun i => f i.succ) x >>= f 0 := foldrM_loop .. +theorem foldrM_succ [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) : + foldrM (n+1) f = fun x => foldrM n (fun i => f i.succ) x >>= f 0 := by + funext x + exact foldrM_loop .. + +theorem foldrM_succ_last [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) : + foldrM (n+1) f = fun x => f (Fin.last n) x >>= foldrM n (fun i => f i.castSucc) := by + funext x + induction n generalizing x with + | zero => simp [foldrM_succ] + | succ n ih => + rw [foldrM_succ] + conv => rhs; rw [foldrM_succ] + simp [ih] + +theorem foldrM_add [Monad m] [LawfulMonad m] (f : Fin (n + k) → α → m α) : + foldrM (n + k) f = + fun x => foldrM k (fun i => f (i.natAdd n)) x >>= foldrM n (fun i => f (i.castLE (Nat.le_add_right n k))) := by + induction k with + | zero => + simp + | succ k ih => + funext x + simp [foldrM_succ_last, ← Nat.add_assoc, ih] /-! ### foldl -/ +@[congr] theorem foldl_congr {n k : Nat} (w : n = k) (f : α → Fin n → α) : + foldl n f = foldl k (fun x i => f x (i.cast w.symm)) := by + subst w + rfl + theorem foldl_loop_lt (f : α → Fin n → α) (x) (h : i < n) : foldl.loop n f x i = foldl.loop n f (f x ⟨i, h⟩) (i+1) := by rw [foldl.loop, dif_pos h] @@ -187,14 +255,34 @@ theorem foldl_succ_last (f : α → Fin (n+1) → α) (x) : rw [foldl_succ] induction n generalizing x with | zero => simp [foldl_succ, Fin.last] - | succ n ih => rw [foldl_succ, ih (f · ·.succ), foldl_succ]; simp [succ_castSucc] + | succ n ih => rw [foldl_succ, ih (f · ·.succ), foldl_succ]; simp + +theorem foldl_add (f : α → Fin (n + m) → α) (x) : + foldl (n + m) f x = + foldl m (fun x i => f x (i.natAdd n)) + (foldl n (fun x i => f x (i.castLE (Nat.le_add_right n m))) x):= by + induction m with + | zero => simp + | succ m ih => simp [foldl_succ_last, ih, ← Nat.add_assoc] theorem foldl_eq_foldlM (f : α → Fin n → α) (x) : foldl n f x = foldlM (m:=Id) n f x := by induction n generalizing x <;> simp [foldl_succ, foldlM_succ, *] +-- This is not marked `@[simp]` as it would match on every occurrence of `foldlM`. +theorem foldlM_pure [Monad m] [LawfulMonad m] {n} {f : α → Fin n → α} : + foldlM n (fun x i => pure (f x i)) x = (pure (foldl n f x) : m α) := by + induction n generalizing x with + | zero => simp + | succ n ih => simp [foldlM_succ, foldl_succ, ih] + /-! ### foldr -/ +@[congr] theorem foldr_congr {n k : Nat} (w : n = k) (f : Fin n → α → α) : + foldr n f = foldr k (fun i => f (i.cast w.symm)) := by + subst w + rfl + theorem foldr_loop_zero (f : Fin n → α → α) (x) : foldr.loop n f 0 (Nat.zero_le _) x = x := by rw [foldr.loop] @@ -220,7 +308,15 @@ theorem foldr_succ_last (f : Fin (n+1) → α → α) (x) : foldr (n+1) f x = foldr n (f ·.castSucc) (f (last n) x) := by induction n generalizing x with | zero => simp [foldr_succ, Fin.last] - | succ n ih => rw [foldr_succ, ih (f ·.succ), foldr_succ]; simp [succ_castSucc] + | succ n ih => rw [foldr_succ, ih (f ·.succ), foldr_succ]; simp + +theorem foldr_add (f : Fin (n + m) → α → α) (x) : + foldr (n + m) f x = + foldr n (fun i => f (i.castLE (Nat.le_add_right n m))) + (foldr m (fun i => f (i.natAdd n)) x) := by + induction m generalizing x with + | zero => simp + | succ m ih => simp [foldr_succ_last, ih, ← Nat.add_assoc] theorem foldr_eq_foldrM (f : Fin n → α → α) (x) : foldr n f x = foldrM (m:=Id) n f x := by @@ -238,4 +334,11 @@ theorem foldr_rev (f : α → Fin n → α) (x) : | zero => simp | succ n ih => rw [foldl_succ_last, foldr_succ, ← ih]; simp [rev_succ] +-- This is not marked `@[simp]` as it would match on every occurrence of `foldrM`. +theorem foldrM_pure [Monad m] [LawfulMonad m] {n} {f : Fin n → α → α} : + foldrM n (fun i x => pure (f i x)) x = (pure (foldr n f x) : m α) := by + induction n generalizing x with + | zero => simp + | succ n ih => simp [foldrM_succ, foldr_succ, ih] + end Fin diff --git a/src/Init/Data/Fin/Lemmas.lean b/src/Init/Data/Fin/Lemmas.lean index 35b6d7e86a..f4c5c2304e 100644 --- a/src/Init/Data/Fin/Lemmas.lean +++ b/src/Init/Data/Fin/Lemmas.lean @@ -646,6 +646,20 @@ theorem rev_castSucc (k : Fin n) : rev (castSucc k) = succ (rev k) := k.rev_cast theorem rev_succ (k : Fin n) : rev (succ k) = castSucc (rev k) := k.rev_addNat 1 +@[simp, grind _=_] +theorem castSucc_succ (i : Fin n) : i.succ.castSucc = i.castSucc.succ := rfl + +@[simp, grind =] +theorem castLE_refl (h : n ≤ n) (i : Fin n) : i.castLE h = i := rfl + +@[simp, grind =] +theorem castSucc_castLE (h : n ≤ m) (i : Fin n) : + (i.castLE h).castSucc = i.castLE (by omega) := rfl + +@[simp, grind =] +theorem castSucc_natAdd (n : Nat) (i : Fin k) : + (i.natAdd n).castSucc = (i.castSucc).natAdd n := rfl + /-! ### pred -/ @[simp] theorem coe_pred (j : Fin (n + 1)) (h : j ≠ 0) : (j.pred h : Nat) = j - 1 := rfl diff --git a/src/Init/Data/List/FinRange.lean b/src/Init/Data/List/FinRange.lean index c900188eec..747132481c 100644 --- a/src/Init/Data/List/FinRange.lean +++ b/src/Init/Data/List/FinRange.lean @@ -6,7 +6,8 @@ Authors: François G. Dorais module prelude -import Init.Data.List.OfFn +import all Init.Data.List.OfFn +import Init.Data.List.Monadic set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables. set_option linter.indexVariables true -- Enforce naming conventions for index variables. @@ -57,3 +58,50 @@ theorem finRange_reverse {n} : (finRange n).reverse = (finRange n).map Fin.rev : simp [Fin.rev_succ] end List + +namespace Fin + +theorem foldlM_eq_finRange_foldlM [Monad m] (f : α → Fin n → m α) (x : α) : + foldlM n f x = (List.finRange n).foldlM f x := by + induction n generalizing x with + | zero => simp + | succ n ih => + simp [foldlM_succ, List.finRange_succ, List.foldlM_cons] + congr 1 + funext y + simp [ih, List.foldlM_map] + +theorem foldrM_eq_finRange_foldrM [Monad m] [LawfulMonad m] (f : Fin n → α → m α) (x : α) : + foldrM n f x = (List.finRange n).foldrM f x := by + induction n generalizing x with + | zero => simp + | succ n ih => + simp [foldrM_succ, List.finRange_succ, ih, List.foldrM_map] + +theorem foldl_eq_finRange_foldl (f : α → Fin n → α) (x : α) : + foldl n f x = (List.finRange n).foldl f x := by + induction n generalizing x with + | zero => simp + | succ n ih => + simp [foldl_succ, List.finRange_succ, ih, List.foldl_map] + +theorem foldr_eq_finRange_foldr (f : Fin n → α → α) (x : α) : + foldr n f x = (List.finRange n).foldr f x := by + induction n generalizing x with + | zero => simp + | succ n ih => + simp [foldr_succ, List.finRange_succ, ih, List.foldr_map] + +end Fin + +namespace List + +theorem ofFnM_succ {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} : + ofFnM f = (do + let a ← f 0 + let as ← ofFnM fun i => f i.succ + pure (a :: as)) := by + simp [ofFnM, Fin.foldlM_eq_finRange_foldlM, List.finRange_succ, List.foldlM_cons_eq_append, + List.foldlM_map] + +end List diff --git a/src/Init/Data/List/Lemmas.lean b/src/Init/Data/List/Lemmas.lean index 662866c7dc..cbdc03ea5b 100644 --- a/src/Init/Data/List/Lemmas.lean +++ b/src/Init/Data/List/Lemmas.lean @@ -2576,6 +2576,11 @@ theorem foldr_eq_foldrM {f : α → β → β} {b : β} {l : List α} : l.foldl (fun xs y => f y :: xs) l' = (l.map f).reverse ++ l' := by induction l generalizing l' <;> simp [*] +/-- Variant of `foldl_flip_cons_eq_append` specalized to `f = id`. -/ +@[simp, grind] theorem foldl_flip_cons_eq_append' {l l' : List α} : + l.foldl (fun xs y => y :: xs) l' = l.reverse ++ l' := by + induction l generalizing l' <;> simp [*] + @[simp, grind] theorem foldr_append_eq_append {l : List α} {f : α → List β} {l' : List β} : l.foldr (f · ++ ·) l' = (l.map f).flatten ++ l' := by induction l <;> simp [*] diff --git a/src/Init/Data/List/Monadic.lean b/src/Init/Data/List/Monadic.lean index a172ad11a1..376dd4f176 100644 --- a/src/Init/Data/List/Monadic.lean +++ b/src/Init/Data/List/Monadic.lean @@ -8,6 +8,8 @@ module prelude import Init.Data.List.TakeDrop import Init.Data.List.Attach +import Init.Data.List.OfFn +import Init.Data.Array.Bootstrap import all Init.Data.List.Control /-! @@ -69,13 +71,17 @@ theorem mapM'_eq_mapM [Monad m] [LawfulMonad m] {f : α → m β} {l : List α} @[simp] theorem mapM_id {l : List α} {f : α → Id β} : l.mapM f = l.map f := mapM_pure +@[simp] theorem mapM_map [Monad m] [LawfulMonad m] {f : α → β} {g : β → m γ} {l : List α} : + (l.map f).mapM g = l.mapM (g ∘ f) := by + induction l <;> simp_all + @[simp] theorem mapM_append [Monad m] [LawfulMonad m] {f : α → m β} {l₁ l₂ : List α} : (l₁ ++ l₂).mapM f = (return (← l₁.mapM f) ++ (← l₂.mapM f)) := by induction l₁ <;> simp [*] /-- Auxiliary lemma for `mapM_eq_reverse_foldlM_cons`. -/ theorem foldlM_cons_eq_append [Monad m] [LawfulMonad m] {f : α → m β} {as : List α} {b : β} {bs : List β} : - (as.foldlM (init := b :: bs) fun acc a => return ((← f a) :: acc)) = - (· ++ b :: bs) <$> as.foldlM (init := []) fun acc a => return ((← f a) :: acc) := by + (as.foldlM (init := b :: bs) fun acc a => (· :: acc) <$> f a) = + (· ++ b :: bs) <$> as.foldlM (init := []) fun acc a => (· :: acc) <$> f a := by induction as generalizing b bs with | nil => simp | cons a as ih => @@ -83,7 +89,7 @@ theorem foldlM_cons_eq_append [Monad m] [LawfulMonad m] {f : α → m β} {as : simp [ih, _root_.map_bind, Functor.map_map, Function.comp_def] theorem mapM_eq_reverse_foldlM_cons [Monad m] [LawfulMonad m] {f : α → m β} {l : List α} : - mapM f l = reverse <$> (l.foldlM (fun acc a => return ((← f a) :: acc)) []) := by + mapM f l = reverse <$> (l.foldlM (fun acc a => (· :: acc) <$> f a) []) := by rw [← mapM'_eq_mapM] induction l with | nil => simp diff --git a/src/Init/Data/List/OfFn.lean b/src/Init/Data/List/OfFn.lean index 4a1af6073d..d9b8749f11 100644 --- a/src/Init/Data/List/OfFn.lean +++ b/src/Init/Data/List/OfFn.lean @@ -27,6 +27,13 @@ Examples: -/ def ofFn {n} (f : Fin n → α) : List α := Fin.foldr n (f · :: ·) [] +/-- +Creates a list wrapped in a monad by applying the monadic function `f : Fin n → m α` +to each potential index in order, starting at `0`. +-/ +def ofFnM {n} [Monad m] (f : Fin n → m α) : m (List α) := + List.reverse <$> Fin.foldlM n (fun xs i => (· :: xs) <$> f i) [] + @[simp] theorem length_ofFn {f : Fin n → α} : (ofFn f).length = n := by simp only [ofFn] @@ -49,7 +56,8 @@ protected theorem getElem_ofFn {f : Fin n → α} (h : i < (ofFn f).length) : simp_all @[simp] -protected theorem getElem?_ofFn {f : Fin n → α} : (ofFn f)[i]? = if h : i < n then some (f ⟨i, h⟩) else none := +protected theorem getElem?_ofFn {f : Fin n → α} : + (ofFn f)[i]? = if h : i < n then some (f ⟨i, h⟩) else none := if h : i < (ofFn f).length then by rw [getElem?_eq_getElem h, List.getElem_ofFn] @@ -60,16 +68,31 @@ protected theorem getElem?_ofFn {f : Fin n → α} : (ofFn f)[i]? = if h : i < n /-- `ofFn` on an empty domain is the empty list. -/ @[simp] -theorem ofFn_zero {f : Fin 0 → α} : ofFn f = [] := - ext_get (by simp) (fun i hi₁ hi₂ => by contradiction) +theorem ofFn_zero {f : Fin 0 → α} : ofFn f = [] := by + rw [ofFn, Fin.foldr_zero] -@[simp] theorem ofFn_succ {n} {f : Fin (n + 1) → α} : ofFn f = f 0 :: ofFn fun i => f i.succ := ext_get (by simp) (fun i hi₁ hi₂ => by cases i · simp · simp) +theorem ofFn_succ_last {n} {f : Fin (n + 1) → α} : + ofFn f = (ofFn fun i => f i.castSucc) ++ [f (Fin.last n)] := by + induction n with + | zero => simp [ofFn_succ] + | succ n ih => + rw [ofFn_succ] + conv => rhs; rw [ofFn_succ] + rw [ih] + simp + +theorem ofFn_add {n m} {f : Fin (n + m) → α} : + ofFn f = (ofFn fun i => f (i.castLE (Nat.le_add_right n m))) ++ (ofFn fun i => f (i.natAdd n)) := by + induction m with + | zero => simp + | succ m ih => simp [ofFn_succ_last, ih] + @[simp] theorem ofFn_eq_nil_iff {f : Fin n → α} : ofFn f = [] ↔ n = 0 := by cases n <;> simp only [ofFn_zero, ofFn_succ, eq_self_iff_true, Nat.succ_ne_zero, reduceCtorEq] @@ -92,4 +115,66 @@ theorem getLast_ofFn {n} {f : Fin n → α} (h : ofFn f ≠ []) : (ofFn f).getLast h = f ⟨n - 1, Nat.sub_one_lt (mt ofFn_eq_nil_iff.2 h)⟩ := by simp [getLast_eq_getElem, length_ofFn, List.getElem_ofFn] +/-- `ofFnM` on an empty domain is the empty list. -/ +@[simp] +theorem ofFnM_zero [Monad m] [LawfulMonad m] {f : Fin 0 → m α} : ofFnM f = pure [] := by + simp [ofFnM] + +/-! See `Init.Data.List.FinRange` for the `ofFnM_succ` variant. -/ + +theorem ofFnM_succ_last {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} : + ofFnM f = (do + let as ← ofFnM fun i => f i.castSucc + let a ← f (Fin.last n) + pure (as ++ [a])) := by + simp [ofFnM, Fin.foldlM_succ_last] + +theorem ofFnM_add {n m} [Monad m] [LawfulMonad m] {f : Fin (n + k) → m α} : + ofFnM f = (do + let as ← ofFnM fun i : Fin n => f (i.castLE (Nat.le_add_right n k)) + let bs ← ofFnM fun i : Fin k => f (i.natAdd n) + pure (as ++ bs)) := by + induction k with + | zero => simp + | succ k ih => simp [ofFnM_succ_last, ih] + + +end List + +namespace Fin + +theorem foldl_cons_eq_append {f : Fin n → α} {xs : List α} : + Fin.foldl n (fun xs i => f i :: xs) xs = (List.ofFn f).reverse ++ xs := by + induction n generalizing xs with + | zero => simp + | succ n ih => simp [Fin.foldl_succ, List.ofFn_succ, ih] + +theorem foldr_cons_eq_append {f : Fin n → α} {xs : List α} : + Fin.foldr n (fun i xs => f i :: xs) xs = List.ofFn f ++ xs:= by + induction n generalizing xs with + | zero => simp + | succ n ih => simp [Fin.foldr_succ, List.ofFn_succ, ih] + +end Fin + +namespace List + +@[simp] +theorem ofFnM_pure_comp [Monad m] [LawfulMonad m] {n} {f : Fin n → α} : + ofFnM (pure ∘ f) = (pure (ofFn f) : m (List α)) := by + simp [ofFnM, Fin.foldlM_pure, Fin.foldl_cons_eq_append] + +-- Variant of `ofFnM_pure_comp` using a lambda. +-- This is not marked a `@[simp]` as it would match on every occurrence of `ofFnM`. +theorem ofFnM_pure [Monad m] [LawfulMonad m] {n} {f : Fin n → α} : + ofFnM (fun i => pure (f i)) = (pure (ofFn f) : m (List α)) := + ofFnM_pure_comp + +@[simp, grind =] theorem idRun_ofFnM {f : Fin n → Id α} : + Id.run (ofFnM f) = ofFn (fun i => Id.run (f i)) := by + unfold Id.run + induction n with + | zero => simp + | succ n ih => simp [ofFnM_succ_last, ofFn_succ_last, ih] + end List diff --git a/src/Init/Data/List/ToArray.lean b/src/Init/Data/List/ToArray.lean index 70aa5ebb7b..f558f3b025 100644 --- a/src/Init/Data/List/ToArray.lean +++ b/src/Init/Data/List/ToArray.lean @@ -210,12 +210,6 @@ theorem forM_toArray [Monad m] (l : List α) (f : α → m PUnit) : cases as simp -@[simp] theorem foldl_push {l : List α} {as : Array α} : l.foldl Array.push as = as ++ l.toArray := by - induction l generalizing as <;> simp [*] - -@[simp] theorem foldr_push {l : List α} {as : Array α} : l.foldr (fun a bs => push bs a) as = as ++ l.reverse.toArray := by - rw [foldr_eq_foldl_reverse, foldl_push] - @[simp, grind =] theorem findSomeM?_toArray [Monad m] [LawfulMonad m] (f : α → m (Option β)) (l : List α) : l.toArray.findSomeM? f = l.findSomeM? f := by rw [Array.findSomeM?] diff --git a/src/Init/Data/Nat/Fold.lean b/src/Init/Data/Nat/Fold.lean index f1dd702ec7..37f3546549 100644 --- a/src/Init/Data/Nat/Fold.lean +++ b/src/Init/Data/Nat/Fold.lean @@ -197,6 +197,8 @@ theorem allTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bo omega go n 0 f +/-! ### `fold` -/ + @[simp] theorem fold_zero {α : Type u} (f : (i : Nat) → i < 0 → α → α) (init : α) : fold 0 f init = init := by simp [fold] @@ -210,6 +212,8 @@ theorem fold_eq_finRange_foldl {α : Type u} (n : Nat) (f : (i : Nat) → i < n | succ n ih => simp [ih, List.finRange_succ_last, List.foldl_map] +/-! ### `foldRev` -/ + @[simp] theorem foldRev_zero {α : Type u} (f : (i : Nat) → i < 0 → α → α) (init : α) : foldRev 0 f init = init := by simp [foldRev] @@ -223,10 +227,12 @@ theorem foldRev_eq_finRange_foldr {α : Type u} (n : Nat) (f : (i : Nat) → i < | zero => simp | succ n ih => simp [ih, List.finRange_succ_last, List.foldr_map] +/-! ### `any` -/ + @[simp] theorem any_zero {f : (i : Nat) → i < 0 → Bool} : any 0 f = false := by simp [any] @[simp] theorem any_succ {n : Nat} (f : (i : Nat) → i < n + 1 → Bool) : - any (n + 1) f = (any n (fun i h => f i (by omega)) || f n (by omega)) := by simp [any] + any (n + 1) f = (any n (fun i h => f i (by omega)) || f n (by omega)) := by simp [any] theorem any_eq_finRange_any {n : Nat} (f : (i : Nat) → i < n → Bool) : any n f = (List.finRange n).any (fun ⟨i, h⟩ => f i h) := by @@ -234,10 +240,12 @@ theorem any_eq_finRange_any {n : Nat} (f : (i : Nat) → i < n → Bool) : | zero => simp | succ n ih => simp [ih, List.finRange_succ_last, List.any_map, Function.comp_def] +/-! ### `all` -/ + @[simp] theorem all_zero {f : (i : Nat) → i < 0 → Bool} : all 0 f = true := by simp [all] @[simp] theorem all_succ {n : Nat} (f : (i : Nat) → i < n + 1 → Bool) : - all (n + 1) f = (all n (fun i h => f i (by omega)) && f n (by omega)) := by simp [all] + all (n + 1) f = (all n (fun i h => f i (by omega)) && f n (by omega)) := by simp [all] theorem all_eq_finRange_all {n : Nat} (f : (i : Nat) → i < n → Bool) : all n f = (List.finRange n).all (fun ⟨i, h⟩ => f i h) := by @@ -250,7 +258,7 @@ end Nat namespace Prod /-- -Combines an initial value with each natural number from in a range, in increasing order. +Combines an initial value with each natural number from a range, in increasing order. In particular, `(start, stop).foldI f init` applies `f`on all the numbers from `start` (inclusive) to `stop` (exclusive) in increasing order: @@ -260,7 +268,7 @@ Examples: * `(5, 8).foldI (fun j _ _ xs => xs.push j) #[] = #[5, 6, 7]` * `(5, 8).foldI (fun j _ _ xs => toString j :: xs) [] = ["7", "6", "5"]` -/ -@[inline] def foldI {α : Type u} (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → α → α) (init : α) : α := +@[inline, simp] def foldI {α : Type u} (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → α → α) (init : α) : α := (i.2 - i.1).fold (fun j _ => f (i.1 + j) (by omega) (by omega)) init /-- @@ -274,7 +282,7 @@ Examples: * `(5, 8).anyI (fun j _ _ => j % 2 = 0) = true` * `(6, 6).anyI (fun j _ _ => j % 2 = 0) = false` -/ -@[inline] def anyI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool := +@[inline, simp] def anyI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool := (i.2 - i.1).any (fun j _ => f (i.1 + j) (by omega) (by omega)) /-- @@ -288,7 +296,7 @@ Examples: * `(5, 8).allI (fun j _ _ => j % 2 = 0) = false` * `(6, 7).allI (fun j _ _ => j % 2 = 0) = true` -/ -@[inline] def allI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool := +@[inline, simp] def allI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool := (i.2 - i.1).all (fun j _ => f (i.1 + j) (by omega) (by omega)) end Prod diff --git a/src/Init/Data/Vector/Basic.lean b/src/Init/Data/Vector/Basic.lean index 68b4454ee5..8526e8cae7 100644 --- a/src/Init/Data/Vector/Basic.lean +++ b/src/Init/Data/Vector/Basic.lean @@ -307,6 +307,8 @@ abbrev zipWithIndex := @zipIdx @[inline] def ofFn (f : Fin n → α) : Vector α n := ⟨Array.ofFn f, by simp⟩ +/-! See also `Vector.ofFnM` defined in `Init.Data.Vector.OfFn`. -/ + /-- Swap two elements of a vector using `Fin` indices. diff --git a/src/Init/Data/Vector/Lemmas.lean b/src/Init/Data/Vector/Lemmas.lean index 4acb3c832d..812d5d69d8 100644 --- a/src/Init/Data/Vector/Lemmas.lean +++ b/src/Init/Data/Vector/Lemmas.lean @@ -53,9 +53,9 @@ theorem toArray_mk {xs : Array α} (h : xs.size = n) : (Vector.mk xs h).toArray (Vector.mk xs size).contains a = xs.contains a := by simp [contains] -@[simp] theorem push_mk {xs : Array α} {size : xs.size = n} {x : α} : - (Vector.mk xs size).push x = - Vector.mk (xs.push x) (by simp [size, Nat.succ_eq_add_one]) := rfl +@[simp] theorem push_mk {xs : Array α} {size : xs.size = n} : + (Vector.mk xs size).push = + fun x => Vector.mk (xs.push x) (by simp [size, Nat.succ_eq_add_one]) := rfl @[simp] theorem pop_mk {xs : Array α} {size : xs.size = n} : (Vector.mk xs size).pop = Vector.mk xs.pop (by simp [size]) := rfl @@ -1660,12 +1660,12 @@ theorem forall_mem_append {p : α → Prop} {xs : Vector α n} {ys : Vector α m (∀ (x) (_ : x ∈ xs ++ ys), p x) ↔ (∀ (x) (_ : x ∈ xs), p x) ∧ (∀ (x) (_ : x ∈ ys), p x) := by simp only [mem_append, or_imp, forall_and] -@[grind] +@[simp, grind] theorem empty_append {xs : Vector α n} : (#v[] : Vector α 0) ++ xs = xs.cast (by omega) := by rcases xs with ⟨as, rfl⟩ simp -@[grind] +@[simp, grind] theorem append_empty {xs : Vector α n} : xs ++ (#v[] : Vector α 0) = xs := by rw [← toArray_inj, toArray_append, Array.append_empty] diff --git a/src/Init/Data/Vector/Monadic.lean b/src/Init/Data/Vector/Monadic.lean index 24df9338dd..cf30e5149c 100644 --- a/src/Init/Data/Vector/Monadic.lean +++ b/src/Init/Data/Vector/Monadic.lean @@ -38,6 +38,11 @@ theorem mapM_pure [Monad m] [LawfulMonad m] {xs : Vector α n} (f : α → β) : apply map_toArray_inj.mp simp +@[simp] theorem mapM_map [Monad m] [LawfulMonad m] {f : α → β} {g : β → m γ} {xs : Vector α n} : + (xs.map f).mapM g = xs.mapM (g ∘ f) := by + apply map_toArray_inj.mp + simp + @[congr] theorem mapM_congr [Monad m] {xs ys : Vector α n} (w : xs = ys) {f : α → m β} : xs.mapM f = ys.mapM f := by diff --git a/src/Init/Data/Vector/OfFn.lean b/src/Init/Data/Vector/OfFn.lean index ed8d96f957..87002994b5 100644 --- a/src/Init/Data/Vector/OfFn.lean +++ b/src/Init/Data/Vector/OfFn.lean @@ -8,6 +8,7 @@ module prelude import all Init.Data.Vector.Basic import Init.Data.Vector.Lemmas +import Init.Data.Vector.Monadic import Init.Data.Array.OfFn /-! @@ -40,4 +41,122 @@ theorem back_ofFn {n} [NeZero n] {f : Fin n → α} : (ofFn f).back = f ⟨n - 1, by have := NeZero.ne n; omega⟩ := by simp [back] +theorem ofFn_succ {f : Fin (n+1) → α} : + ofFn f = (ofFn (fun (i : Fin n) => f i.castSucc)).push (f ⟨n, by omega⟩) := by + ext i h + · simp only [getElem_ofFn, getElem_push, Fin.castSucc_mk, left_eq_dite_iff] + intro h' + have : i = n := by omega + simp_all + +theorem ofFn_add {n m} {f : Fin (n + m) → α} : + ofFn f = (ofFn (fun i => f (i.castLE (Nat.le_add_right n m)))) ++ (ofFn (fun i => f (i.natAdd n))) := by + apply Vector.toArray_inj.mp + simp [Array.ofFn_add] + +theorem ofFn_succ' {f : Fin (n+1) → α} : + ofFn f = (#v[f 0] ++ ofFn (fun i => f i.succ)).cast (by omega) := by + apply Vector.toArray_inj.mp + simp [Array.ofFn_succ'] + +/-! ### ofFnM -/ + +/-- Construct (in a monadic context) a vector by applying a monadic function to each index. -/ +def ofFnM {n} [Monad m] (f : Fin n → m α) : m (Vector α n) := + go 0 (by omega) (Array.emptyWithCapacity n) rfl where + /-- Auxiliary for `ofFn`. `ofFn.go f i acc = acc ++ #v[f i, ..., f(n - 1)]` -/ + go (i : Nat) (h' : i ≤ n) (acc : Array α) (w : acc.size = i) : m (Vector α n) := do + if h : i < n then + go (i+1) (by omega) (acc.push (← f ⟨i, h⟩)) (by simp [w]) + else + pure ⟨acc, by omega⟩ + +@[simp] +theorem ofFnM_zero [Monad m] {f : Fin 0 → m α} : Vector.ofFnM f = pure #v[] := by + simp [ofFnM, ofFnM.go] + +private theorem ofFnM_go_succ {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} + (hi : i ≤ n := by omega) {h : xs.size = i} : + ofFnM.go f i (by omega) xs h = (do + let as ← ofFnM.go (fun i => f i.castSucc) i hi xs h + let a ← f (Fin.last n) + pure (as.push a)) := by + fun_induction ofFnM.go f i (by omega) xs h + case case1 acc h' h ih => + if h : acc.size = n then + unfold ofFnM.go + rw [dif_neg (by omega)] + have h : ¬ acc.size + 1 < n + 1 := by omega + have : Fin.last n = ⟨acc.size, by omega⟩ := by ext; simp; omega + simp [*] + else + have : acc.size + 1 ≤ n := by omega + simp only [ih, this] + conv => rhs; unfold ofFnM.go + rw [dif_pos (by omega)] + simp + case case2 => + omega + +theorem ofFnM_succ {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} : + ofFnM f = (do + let as ← ofFnM fun i => f i.castSucc + let a ← f (Fin.last n) + pure (as.push a)) := by + simp [ofFnM, ofFnM_go_succ] + +theorem ofFnM_add {n m} [Monad m] [LawfulMonad m] {f : Fin (n + k) → m α} : + ofFnM f = (do + let as ← ofFnM (fun i => f (i.castLE (Nat.le_add_right n k))) + let bs ← ofFnM (fun i => f (i.natAdd n)) + pure (as ++ bs)) := by + induction k with + | zero => simp + | succ k ih => simp [ofFnM_succ, ih, ← push_append] + +@[simp, grind] theorem toArray_ofFnM [Monad m] [LawfulMonad m] {f : Fin n → m α} : + toArray <$> ofFnM f = Array.ofFnM f := by + induction n with + | zero => simp + | succ n ih => simp [ofFnM_succ, Array.ofFnM_succ, ← ih] + +@[simp, grind] theorem toList_ofFnM [Monad m] [LawfulMonad m] {f : Fin n → m α} : + toList <$> Vector.ofFnM f = List.ofFnM f := by + unfold toList + suffices Array.toList <$> (toArray <$> ofFnM f) = List.ofFnM f by + simpa [-toArray_ofFnM] + simp + +theorem ofFnM_succ' {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} : + ofFnM f = (do + let a ← f 0 + let as ← ofFnM fun i => f i.succ + pure ((#v[a] ++ as).cast (by omega))) := by + apply Vector.map_toArray_inj.mp + simp only [toArray_ofFnM, Array.ofFnM_succ', bind_pure_comp, map_bind, Functor.map_map, + toArray_cast, toArray_append] + congr 1 + funext x + have : (fun xs : Vector α n => #[x] ++ xs.toArray) = (#[x] ++ ·) ∘ toArray := by funext xs; simp + simp [this, comp_map] + +@[simp] +theorem ofFnM_pure_comp [Monad m] [LawfulMonad m] {n} {f : Fin n → α} : + ofFnM (pure ∘ f) = (pure (ofFn f) : m (Vector α n)) := by + apply Vector.map_toArray_inj.mp + simp + +-- Variant of `ofFnM_pure_comp` using a lambda. +-- This is not marked a `@[simp]` as it would match on every occurrence of `ofFnM`. +theorem ofFnM_pure [Monad m] [LawfulMonad m] {n} {f : Fin n → α} : + ofFnM (fun i => pure (f i)) = (pure (ofFn f) : m (Vector α n)) := + ofFnM_pure_comp + +@[simp, grind =] theorem idRun_ofFnM {f : Fin n → Id α} : + Id.run (ofFnM f) = ofFn (fun i => Id.run (f i)) := by + unfold Id.run + induction n with + | zero => simp + | succ n ih => simp [ofFnM_succ', ofFn_succ', ih] + end Vector