diff --git a/src/Init/Control/Basic.lean b/src/Init/Control/Basic.lean index 61a73e9228..219130d1e1 100644 --- a/src/Init/Control/Basic.lean +++ b/src/Init/Control/Basic.lean @@ -8,6 +8,28 @@ import Init.Core universe u v w +/-- +A `ForIn'` instance, which handles `for h : x in c do`, +can also handle `for x in x do` by ignoring `h`, and so provides a `ForIn` instance. +-/ +instance (priority := low) instForInOfForIn' [ForIn' m ρ α d] : ForIn m ρ α where + forIn x b f := forIn' x b fun a _ => f a + +@[simp] theorem forIn'_eq_forIn [d : Membership α ρ] [ForIn' m ρ α d] {β} [Monad m] (x : ρ) (b : β) + (f : (a : α) → a ∈ x → β → m (ForInStep β)) (g : (a : α) → β → m (ForInStep β)) + (h : ∀ a m b, f a m b = g a b) : + forIn' x b f = forIn x b g := by + simp [instForInOfForIn'] + congr + apply funext + intro a + apply funext + intro m + apply funext + intro b + simp [h] + rfl + @[reducible] def Functor.mapRev {f : Type u → Type v} [Functor f] {α β : Type u} : f α → (α → β) → f β := fun a f => f <$> a diff --git a/src/Init/Core.lean b/src/Init/Core.lean index b98fb6afdf..46b3644d19 100644 --- a/src/Init/Core.lean +++ b/src/Init/Core.lean @@ -324,7 +324,6 @@ class ForIn' (m : Type u₁ → Type u₂) (ρ : Type u) (α : outParam (Type v) export ForIn' (forIn') - /-- Auxiliary type used to compile `do` notation. It is used when compiling a do block nested inside a combinator like `tryCatch`. It encodes the possible ways the diff --git a/src/Init/Data/Array/Basic.lean b/src/Init/Data/Array/Basic.lean index 88117d0654..fec1b7bfee 100644 --- a/src/Init/Data/Array/Basic.lean +++ b/src/Init/Data/Array/Basic.lean @@ -82,6 +82,22 @@ theorem ext' {as bs : Array α} (h : as.toList = bs.toList) : as = bs := by @[simp] theorem getElem_toList {a : Array α} {i : Nat} (h : i < a.size) : a.toList[i] = a[i] := rfl +/-- `a ∈ as` is a predicate which asserts that `a` is in the array `as`. -/ +-- NB: This is defined as a structure rather than a plain def so that a lemma +-- like `sizeOf_lt_of_mem` will not apply with no actual arrays around. +structure Mem (as : Array α) (a : α) : Prop where + val : a ∈ as.toList + +instance : Membership α (Array α) where + mem := Mem + +theorem mem_def {a : α} {as : Array α} : a ∈ as ↔ a ∈ as.toList := + ⟨fun | .mk h => h, Array.Mem.mk⟩ + +@[simp] theorem getElem_mem {l : Array α} {i : Nat} (h : i < l.size) : l[i] ∈ l := by + rw [Array.mem_def, ← getElem_toList] + apply List.getElem_mem + end Array namespace List @@ -316,6 +332,37 @@ protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m instance : ForIn m (Array α) α where forIn := Array.forIn +/-- See comment at `forInUnsafe` -/ +@[inline] unsafe def forIn'Unsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : (a : α) → a ∈ as → β → m (ForInStep β)) : m β := + let sz := as.usize + let rec @[specialize] loop (i : USize) (b : β) : m β := do + if i < sz then + let a := as.uget i lcProof + match (← f a lcProof b) with + | ForInStep.done b => pure b + | ForInStep.yield b => loop (i+1) b + else + pure b + loop 0 b + +/-- Reference implementation for `forIn'` -/ +@[implemented_by Array.forIn'Unsafe] +protected def forIn' {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : (a : α) → a ∈ as → β → m (ForInStep β)) : m β := + let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do + match i, h with + | 0, _ => pure b + | i+1, h => + have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h + have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide) + have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this + match (← f as[as.size - 1 - i] (getElem_mem this) b) with + | ForInStep.done b => pure b + | ForInStep.yield b => loop i (Nat.le_of_lt h') b + loop as.size (Nat.le_refl _) b + +instance : ForIn' m (Array α) α inferInstance where + forIn' := Array.forIn' + /-- See comment at `forInUnsafe` -/ @[inline] unsafe def foldlMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : β → α → m β) (init : β) (as : Array α) (start := 0) (stop := as.size) : m β := diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 3ba16ef081..474491cb3b 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -21,8 +21,7 @@ namespace Array @[simp] theorem getElem_mk {xs : List α} {i : Nat} (h : i < xs.length) : (Array.mk xs)[i] = xs[i] := rfl -theorem getElem_eq_getElem_toList {a : Array α} (h : i < a.size) : a[i] = a.toList[i] := by - by_cases i < a.size <;> (try simp [*]) <;> rfl +theorem getElem_eq_getElem_toList {a : Array α} (h : i < a.size) : a[i] = a.toList[i] := rfl theorem getElem?_eq_getElem {a : Array α} {i : Nat} (h : i < a.size) : a[i]? = some a[i] := getElem?_pos .. @@ -85,6 +84,9 @@ We prefer to pull `List.toArray` outwards. (a.toArrayAux b).size = b.size + a.length := by simp [size] +@[simp] theorem mem_toArray {a : α} {l : List α} : a ∈ l.toArray ↔ a ∈ l := by + simp [mem_def] + @[simp] theorem push_toArray (l : List α) (a : α) : l.toArray.push a = (l ++ [a]).toArray := by apply ext' simp @@ -121,6 +123,30 @@ We prefer to pull `List.toArray` outwards. rw [Array.forIn, forIn_loop_toArray] simp +@[simp] theorem forIn'_loop_toArray [Monad m] (l : List α) (f : (a : α) → a ∈ l.toArray → β → m (ForInStep β)) (i : Nat) + (h : i ≤ l.length) (b : β) : + Array.forIn'.loop l.toArray f i h b = + forIn' (l.drop (l.length - i)) b (fun a m b => f a (by simpa using mem_of_mem_drop m) b) := by + induction i generalizing l b with + | zero => + simp [Array.forIn'.loop] + | succ i ih => + simp only [Array.forIn'.loop, size_toArray, getElem_toArray, ih, forIn_eq_forIn] + have t : drop (l.length - (i + 1)) l = l[l.length - i - 1] :: drop (l.length - i) l := by + simp only [Nat.sub_add_eq] + rw [List.drop_sub_one (by omega), List.getElem?_eq_getElem (by omega)] + simp only [Option.toList_some, singleton_append] + simp [t] + have t : l.length - 1 - i = l.length - i - 1 := by omega + simp only [t] + congr + +@[simp] theorem forIn'_toArray [Monad m] (l : List α) (b : β) (f : (a : α) → a ∈ l.toArray → β → m (ForInStep β)) : + forIn' l.toArray b f = forIn' l b (fun a m b => f a (mem_toArray.mpr m) b) := by + change Array.forIn' _ _ _ = List.forIn' _ _ _ + rw [Array.forIn', forIn'_loop_toArray] + simp [List.forIn_eq_forIn] + theorem foldrM_toArray [Monad m] (f : α → β → m β) (init : β) (l : List α) : l.toArray.foldrM f init = l.foldrM f init := by rw [foldrM_eq_reverse_foldlM_toList] @@ -268,9 +294,6 @@ theorem anyM_stop_le_start [Monad m] (p : α → m Bool) (as : Array α) (start (h : min stop as.size ≤ start) : anyM p as start stop = pure false := by rw [anyM_eq_anyM_loop, anyM.loop, dif_neg (Nat.not_lt.2 h)] -theorem mem_def {a : α} {as : Array α} : a ∈ as ↔ a ∈ as.toList := - ⟨fun | .mk h => h, Array.Mem.mk⟩ - @[simp] theorem not_mem_empty (a : α) : ¬(a ∈ #[]) := by simp [mem_def] @@ -460,10 +483,6 @@ theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} idx < a.size := hidx -@[simp] theorem getElem_mem {l : Array α} {i : Nat} (h : i < l.size) : l[i] ∈ l := by - erw [Array.mem_def, getElem_eq_getElem_toList] - apply List.get_mem - theorem getElem_fin_eq_getElem_toList (a : Array α) (i : Fin a.size) : a[i] = a.toList[i] := rfl @[simp] theorem ugetElem_eq_getElem (a : Array α) {i : USize} (h : i.toNat < a.size) : @@ -728,6 +747,11 @@ theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Arra cases as simp +@[simp] theorem forIn'_toList [Monad m] (as : Array α) (b : β) (f : (a : α) → a ∈ as.toList → β → m (ForInStep β)) : + forIn' as.toList b f = forIn' as b (fun a m b => f a (mem_toList.mpr m) b) := by + cases as + simp + /-! ### foldl / foldr -/ @[simp] theorem foldlM_loop_empty [Monad m] (f : β → α → m β) (init : β) (i j : Nat) : @@ -1411,9 +1435,6 @@ namespace List Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible. -/ -@[simp] theorem mem_toArray {a : α} {l : List α} : a ∈ l.toArray ↔ a ∈ l := by - simp [mem_def] - @[simp] theorem toListRev_toArray (l : List α) : l.toArray.toListRev = l.reverse := by simp diff --git a/src/Init/Data/Array/Mem.lean b/src/Init/Data/Array/Mem.lean index c33251ba39..5887f4bb7f 100644 --- a/src/Init/Data/Array/Mem.lean +++ b/src/Init/Data/Array/Mem.lean @@ -10,15 +10,6 @@ import Init.Data.List.BasicAux namespace Array -/-- `a ∈ as` is a predicate which asserts that `a` is in the array `as`. -/ --- NB: This is defined as a structure rather than a plain def so that a lemma --- like `sizeOf_lt_of_mem` will not apply with no actual arrays around. -structure Mem (as : Array α) (a : α) : Prop where - val : a ∈ as.toList - -instance : Membership α (Array α) where - mem := Mem - theorem sizeOf_lt_of_mem [SizeOf α] {as : Array α} (h : a ∈ as) : sizeOf a < sizeOf as := by cases as with | _ as => exact Nat.lt_trans (List.sizeOf_lt_of_mem h.val) (by simp_arith) diff --git a/src/Init/Data/List/Control.lean b/src/Init/Data/List/Control.lean index 7e6dc918c6..2f945829ac 100644 --- a/src/Init/Data/List/Control.lean +++ b/src/Init/Data/List/Control.lean @@ -254,6 +254,8 @@ instance : ForIn m (List α) α where instance : ForIn' m (List α) α inferInstance where forIn' := List.forIn' +@[simp] theorem forIn'_eq_forIn' [Monad m] : @List.forIn' α β m _ = forIn' := rfl + @[simp] theorem forIn'_eq_forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : List α) (init : β) (f : α → β → m (ForInStep β)) : forIn' as init (fun a _ b => f a b) = forIn as init f := by simp [forIn', forIn, List.forIn, List.forIn'] have : ∀ cs h, List.forIn'.loop cs (fun a _ b => f a b) as init h = List.forIn.loop f as init := by diff --git a/src/Init/Data/List/Lemmas.lean b/src/Init/Data/List/Lemmas.lean index b89341f4cc..1112ba8f65 100644 --- a/src/Init/Data/List/Lemmas.lean +++ b/src/Init/Data/List/Lemmas.lean @@ -492,10 +492,6 @@ theorem getElem?_of_mem {a} {l : List α} (h : a ∈ l) : ∃ n : Nat, l[n]? = s theorem get?_of_mem {a} {l : List α} (h : a ∈ l) : ∃ n, l.get? n = some a := let ⟨⟨n, _⟩, e⟩ := get_of_mem h; ⟨n, e ▸ get?_eq_get _⟩ -@[simp] theorem getElem_mem : ∀ {l : List α} {n} (h : n < l.length), l[n]'h ∈ l - | _ :: _, 0, _ => .head .. - | _ :: l, _+1, _ => .tail _ (getElem_mem (l := l) ..) - theorem get_mem : ∀ (l : List α) n h, get l ⟨n, h⟩ ∈ l | _ :: _, 0, _ => .head .. | _ :: l, _+1, _ => .tail _ (get_mem l ..) diff --git a/src/Init/Data/List/Monadic.lean b/src/Init/Data/List/Monadic.lean index 04dcfda7d8..f8ded7d6f2 100644 --- a/src/Init/Data/List/Monadic.lean +++ b/src/Init/Data/List/Monadic.lean @@ -87,6 +87,68 @@ theorem mapM_eq_reverse_foldlM_cons [Monad m] [LawfulMonad m] (f : α → m β) (l₁ ++ l₂).forM f = (do l₁.forM f; l₂.forM f) := by induction l₁ <;> simp [*] +/-! ### forIn' -/ + +@[simp] theorem forIn'_nil [Monad m] (f : (a : α) → a ∈ [] → β → m (ForInStep β)) (b : β) : forIn' [] b f = pure b := + rfl + +theorem forIn'_loop_congr [Monad m] {as bs : List α} + {f : (a' : α) → a' ∈ as → β → m (ForInStep β)} + {g : (a' : α) → a' ∈ bs → β → m (ForInStep β)} + {b : β} (ha : ∃ ys, ys ++ xs = as) (hb : ∃ ys, ys ++ xs = bs) + (h : ∀ a m m' b, f a m b = g a m' b) : forIn'.loop as f xs b ha = forIn'.loop bs g xs b hb := by + induction xs generalizing b with + | nil => simp [forIn'.loop] + | cons a xs ih => + simp only [forIn'.loop] at * + congr 1 + · rw [h] + · funext s + obtain b | b := s + · rfl + · simp + rw [ih] + +@[simp] theorem forIn'_cons [Monad m] {a : α} {as : List α} + (f : (a' : α) → a' ∈ a :: as → β → m (ForInStep β)) (b : β) : + forIn' (a::as) b f = f a (mem_cons_self a as) b >>= + fun | ForInStep.done b => pure b | ForInStep.yield b => forIn' as b fun a' m b => f a' (mem_cons_of_mem a m) b := by + simp only [forIn', List.forIn', forIn'.loop] + congr 1 + funext s + obtain b | b := s + · rfl + · apply forIn'_loop_congr + intros + rfl + +@[congr] theorem forIn'_congr [Monad m] {as bs : List α} (w : as = bs) + {b b' : β} (hb : b = b') + {f : (a' : α) → a' ∈ as → β → m (ForInStep β)} + {g : (a' : α) → a' ∈ bs → β → m (ForInStep β)} + (h : ∀ a m b, f a (by simpa [w] using m) b = g a m b) : + forIn' as b f = forIn' bs b' g := by + induction bs generalizing as b b' with + | nil => + subst w + simp [hb, forIn'_nil] + | cons b bs ih => + cases as with + | nil => simp at w + | cons a as => + simp only [cons.injEq] at w + obtain ⟨rfl, rfl⟩ := w + simp only [forIn'_cons] + congr 1 + · simp [h, hb] + · funext s + obtain b | b := s + · rfl + · simp + rw [ih rfl rfl] + intro a m b + exact h a (mem_cons_of_mem _ m) b + /-! ### allM -/ theorem allM_eq_not_anyM_not [Monad m] [LawfulMonad m] (p : α → m Bool) (as : List α) : diff --git a/src/Init/Data/Option/List.lean b/src/Init/Data/Option/List.lean index aa22d3912d..794e9f1e9d 100644 --- a/src/Init/Data/Option/List.lean +++ b/src/Init/Data/Option/List.lean @@ -11,4 +11,28 @@ namespace Option @[simp] theorem mem_toList {a : α} {o : Option α} : a ∈ o.toList ↔ a ∈ o := by cases o <;> simp [eq_comm] +@[simp] theorem forIn'_none [Monad m] (b : β) (f : (a : α) → a ∈ none → β → m (ForInStep β)) : + forIn' none b f = pure b := by + rfl + +@[simp] theorem forIn'_some [Monad m] (a : α) (b : β) (f : (a' : α) → a' ∈ some a → β → m (ForInStep β)) : + forIn' (some a) b f = bind (f a rfl b) (fun | .done r | .yield r => pure r) := by + rfl + +@[simp] theorem forIn_none [Monad m] (b : β) (f : α → β → m (ForInStep β)) : + forIn none b f = pure b := by + rfl + +@[simp] theorem forIn_some [Monad m] (a : α) (b : β) (f : α → β → m (ForInStep β)) : + forIn (some a) b f = bind (f a b) (fun | .done r | .yield r => pure r) := by + rfl + +@[simp] theorem forIn'_toList [Monad m] (o : Option α) (b : β) (f : (a : α) → a ∈ o.toList → β → m (ForInStep β)) : + forIn' o.toList b f = forIn' o b fun a m b => f a (by simpa using m) b := by + cases o <;> rfl + +@[simp] theorem forIn_toList [Monad m] (o : Option α) (b : β) (f : α → β → m (ForInStep β)) : + forIn o.toList b f = forIn o b f := by + cases o <;> rfl + end Option diff --git a/src/Init/GetElem.lean b/src/Init/GetElem.lean index 0553cee013..2906fd3204 100644 --- a/src/Init/GetElem.lean +++ b/src/Init/GetElem.lean @@ -207,6 +207,10 @@ instance : GetElem (List α) Nat α fun as i => i < as.length where @[deprecated (since := "2024-06-12")] abbrev cons_getElem_succ := @getElem_cons_succ +@[simp] theorem getElem_mem : ∀ {l : List α} {n} (h : n < l.length), l[n]'h ∈ l + | _ :: _, 0, _ => .head .. + | _ :: l, _+1, _ => .tail _ (getElem_mem (l := l) ..) + theorem get_drop_eq_drop (as : List α) (i : Nat) (h : i < as.length) : as[i] :: as.drop (i+1) = as.drop i := match as, i with | _::_, 0 => rfl diff --git a/tests/lean/run/array_simp.lean b/tests/lean/run/array_simp.lean index 39de50da0d..d0ed4dd48d 100644 --- a/tests/lean/run/array_simp.lean +++ b/tests/lean/run/array_simp.lean @@ -13,3 +13,11 @@ attribute [local simp] Id.run in for i in [1,2,3,4].toArray do s := s + i pure s) ~> 10 + +attribute [local simp] Id.run in +#check_simp + (Id.run do + let mut s := 0 + for h : i in [1,2,3,4].toArray do + s := s + i + pure s) ~> 10 diff --git a/tests/lean/run/treeNode.lean b/tests/lean/run/treeNode.lean index 40f7a318fc..12958ee83c 100644 --- a/tests/lean/run/treeNode.lean +++ b/tests/lean/run/treeNode.lean @@ -14,7 +14,7 @@ def treeToList (t : TreeNode) : List String := return r @[simp] theorem treeToList_eq (name : String) (children : List TreeNode) : treeToList (.mkNode name children) = name :: List.join (children.map treeToList) := by - simp [treeToList, Id.run, forIn, List.forIn] + simp only [treeToList, Id.run, Id.pure_eq, Id.bind_eq, List.forIn'_eq_forIn, forIn, List.forIn] have : ∀ acc, (Id.run do List.forIn.loop (fun a b => ForInStep.yield (b ++ treeToList a)) children acc) = acc ++ List.join (List.map treeToList children) := by intro acc induction children generalizing acc with simp [List.forIn.loop, List.map, List.join, Id.run]