diff --git a/src/Init/Control/Basic.lean b/src/Init/Control/Basic.lean index 219130d1e1..636fd7245d 100644 --- a/src/Init/Control/Basic.lean +++ b/src/Init/Control/Basic.lean @@ -11,8 +11,13 @@ universe u v w /-- A `ForIn'` instance, which handles `for h : x in c do`, can also handle `for x in x do` by ignoring `h`, and so provides a `ForIn` instance. + +Note that this instance will cause a potentially non-defeq duplication if both `ForIn` and `ForIn'` +instances are provided for the same type. -/ -instance (priority := low) instForInOfForIn' [ForIn' m ρ α d] : ForIn m ρ α where +-- We set the priority to 500 so it is below the default, +-- but still above the low priority instance from `Stream`. +instance (priority := 500) instForInOfForIn' [ForIn' m ρ α d] : ForIn m ρ α where forIn x b f := forIn' x b fun a _ => f a @[simp] theorem forIn'_eq_forIn [d : Membership α ρ] [ForIn' m ρ α d] {β} [Monad m] (x : ρ) (b : β) diff --git a/src/Init/Data/Array/Basic.lean b/src/Init/Data/Array/Basic.lean index 39e7d761ab..3e520e51b4 100644 --- a/src/Init/Data/Array/Basic.lean +++ b/src/Init/Data/Array/Basic.lean @@ -302,37 +302,6 @@ def modifyOp (self : Array α) (idx : Nat) (f : α → α) : Array α := We claim this unsafe implementation is correct because an array cannot have more than `usizeSz` elements in our runtime. This kind of low level trick can be removed with a little bit of compiler support. For example, if the compiler simplifies `as.size < usizeSz` to true. -/ -@[inline] unsafe def forInUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : α → β → m (ForInStep β)) : m β := - let sz := as.usize - let rec @[specialize] loop (i : USize) (b : β) : m β := do - if i < sz then - let a := as.uget i lcProof - match (← f a b) with - | ForInStep.done b => pure b - | ForInStep.yield b => loop (i+1) b - else - pure b - loop 0 b - -/-- Reference implementation for `forIn` -/ -@[implemented_by Array.forInUnsafe] -protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : α → β → m (ForInStep β)) : m β := - let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do - match i, h with - | 0, _ => pure b - | i+1, h => - have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h - have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide) - have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this - match (← f as[as.size - 1 - i] b) with - | ForInStep.done b => pure b - | ForInStep.yield b => loop i (Nat.le_of_lt h') b - loop as.size (Nat.le_refl _) b - -instance : ForIn m (Array α) α where - forIn := Array.forIn - -/-- See comment at `forInUnsafe` -/ @[inline] unsafe def forIn'Unsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : (a : α) → a ∈ as → β → m (ForInStep β)) : m β := let sz := as.usize let rec @[specialize] loop (i : USize) (b : β) : m β := do @@ -363,7 +332,9 @@ protected def forIn' {α : Type u} {β : Type v} {m : Type v → Type w} [Monad instance : ForIn' m (Array α) α inferInstance where forIn' := Array.forIn' -/-- See comment at `forInUnsafe` -/ +-- No separate `ForIn` instance is required because it can be derived from `ForIn'`. + +/-- See comment at `forIn'Unsafe` -/ @[inline] unsafe def foldlMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : β → α → m β) (init : β) (as : Array α) (start := 0) (stop := as.size) : m β := let rec @[specialize] fold (i : USize) (stop : USize) (b : β) : m β := do @@ -398,7 +369,7 @@ def foldlM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : β else fold as.size (Nat.le_refl _) -/-- See comment at `forInUnsafe` -/ +/-- See comment at `forIn'Unsafe` -/ @[inline] unsafe def foldrMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α → β → m β) (init : β) (as : Array α) (start := as.size) (stop := 0) : m β := let rec @[specialize] fold (i : USize) (stop : USize) (b : β) : m β := do @@ -437,7 +408,7 @@ def foldrM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α else pure init -/-- See comment at `forInUnsafe` -/ +/-- See comment at `forIn'Unsafe` -/ @[inline] unsafe def mapMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α → m β) (as : Array α) : m (Array β) := let sz := as.usize diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 474491cb3b..c173fd0afb 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -104,25 +104,6 @@ We prefer to pull `List.toArray` outwards. @[simp] theorem back_toArray [Inhabited α] (l : List α) : l.toArray.back = l.getLast! := by simp only [back, size_toArray, Array.get!_eq_getElem!, getElem!_toArray, getLast!_eq_getElem!] -@[simp] theorem forIn_loop_toArray [Monad m] (l : List α) (f : α → β → m (ForInStep β)) (i : Nat) - (h : i ≤ l.length) (b : β) : - Array.forIn.loop l.toArray f i h b = (l.drop (l.length - i)).forIn b f := by - induction i generalizing l b with - | zero => simp [Array.forIn.loop] - | succ i ih => - simp only [Array.forIn.loop, size_toArray, getElem_toArray, ih, forIn_eq_forIn] - rw [Nat.sub_add_eq, List.drop_sub_one (by omega), List.getElem?_eq_getElem (by omega)] - simp only [Option.toList_some, singleton_append, forIn_cons] - have t : l.length - 1 - i = l.length - i - 1 := by omega - simp only [t] - congr - -@[simp] theorem forIn_toArray [Monad m] (l : List α) (b : β) (f : α → β → m (ForInStep β)) : - forIn l.toArray b f = forIn l b f := by - change l.toArray.forIn b f = l.forIn b f - rw [Array.forIn, forIn_loop_toArray] - simp - @[simp] theorem forIn'_loop_toArray [Monad m] (l : List α) (f : (a : α) → a ∈ l.toArray → β → m (ForInStep β)) (i : Nat) (h : i ≤ l.length) (b : β) : Array.forIn'.loop l.toArray f i h b = @@ -131,7 +112,7 @@ We prefer to pull `List.toArray` outwards. | zero => simp [Array.forIn'.loop] | succ i ih => - simp only [Array.forIn'.loop, size_toArray, getElem_toArray, ih, forIn_eq_forIn] + simp only [Array.forIn'.loop, size_toArray, getElem_toArray, ih] have t : drop (l.length - (i + 1)) l = l[l.length - i - 1] :: drop (l.length - i) l := by simp only [Nat.sub_add_eq] rw [List.drop_sub_one (by omega), List.getElem?_eq_getElem (by omega)] @@ -145,7 +126,11 @@ We prefer to pull `List.toArray` outwards. forIn' l.toArray b f = forIn' l b (fun a m b => f a (mem_toArray.mpr m) b) := by change Array.forIn' _ _ _ = List.forIn' _ _ _ rw [Array.forIn', forIn'_loop_toArray] - simp [List.forIn_eq_forIn] + simp + +@[simp] theorem forIn_toArray [Monad m] (l : List α) (b : β) (f : α → β → m (ForInStep β)) : + forIn l.toArray b f = forIn l b f := by + simpa using forIn'_toArray l b fun a m b => f a b theorem foldrM_toArray [Monad m] (f : α → β → m β) (init : β) (l : List α) : l.toArray.foldrM f init = l.foldrM f init := by diff --git a/src/Init/Data/List/Control.lean b/src/Init/Data/List/Control.lean index 2f945829ac..dd17b84167 100644 --- a/src/Init/Data/List/Control.lean +++ b/src/Init/Data/List/Control.lean @@ -215,27 +215,6 @@ def findSomeM? {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f | some b => pure (some b) | none => findSomeM? f as -@[inline] protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : List α) (init : β) (f : α → β → m (ForInStep β)) : m β := - let rec @[specialize] loop - | [], b => pure b - | a::as, b => do - match (← f a b) with - | ForInStep.done b => pure b - | ForInStep.yield b => loop as b - loop as init - -instance : ForIn m (List α) α where - forIn := List.forIn - -@[simp] theorem forIn_eq_forIn [Monad m] : @List.forIn α β m _ = forIn := rfl - -@[simp] theorem forIn_nil [Monad m] (f : α → β → m (ForInStep β)) (b : β) : forIn [] b f = pure b := - rfl - -@[simp] theorem forIn_cons [Monad m] (f : α → β → m (ForInStep β)) (a : α) (as : List α) (b : β) - : forIn (a::as) b f = f a b >>= fun | ForInStep.done b => pure b | ForInStep.yield b => forIn as b f := - rfl - @[inline] protected def forIn' {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : List α) (init : β) (f : (a : α) → a ∈ as → β → m (ForInStep β)) : m β := let rec @[specialize] loop : (as' : List α) → (b : β) → Exists (fun bs => bs ++ as' = as) → m β | [], b, _ => pure b @@ -254,16 +233,15 @@ instance : ForIn m (List α) α where instance : ForIn' m (List α) α inferInstance where forIn' := List.forIn' +-- No separate `ForIn` instance is required because it can be derived from `ForIn'`. + @[simp] theorem forIn'_eq_forIn' [Monad m] : @List.forIn' α β m _ = forIn' := rfl -@[simp] theorem forIn'_eq_forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : List α) (init : β) (f : α → β → m (ForInStep β)) : forIn' as init (fun a _ b => f a b) = forIn as init f := by - simp [forIn', forIn, List.forIn, List.forIn'] - have : ∀ cs h, List.forIn'.loop cs (fun a _ b => f a b) as init h = List.forIn.loop f as init := by - intro cs h - induction as generalizing cs init with - | nil => intros; rfl - | cons a as ih => intros; simp [List.forIn.loop, List.forIn'.loop, ih] - apply this +@[simp] theorem forIn'_nil [Monad m] (f : (a : α) → a ∈ [] → β → m (ForInStep β)) (b : β) : forIn' [] b f = pure b := + rfl + +@[simp] theorem forIn_nil [Monad m] (f : α → β → m (ForInStep β)) (b : β) : forIn [] b f = pure b := + rfl instance : ForM m (List α) α where forM := List.forM diff --git a/src/Init/Data/List/Monadic.lean b/src/Init/Data/List/Monadic.lean index f8ded7d6f2..5836d3d8ff 100644 --- a/src/Init/Data/List/Monadic.lean +++ b/src/Init/Data/List/Monadic.lean @@ -89,9 +89,6 @@ theorem mapM_eq_reverse_foldlM_cons [Monad m] [LawfulMonad m] (f : α → m β) /-! ### forIn' -/ -@[simp] theorem forIn'_nil [Monad m] (f : (a : α) → a ∈ [] → β → m (ForInStep β)) (b : β) : forIn' [] b f = pure b := - rfl - theorem forIn'_loop_congr [Monad m] {as bs : List α} {f : (a' : α) → a' ∈ as → β → m (ForInStep β)} {g : (a' : α) → a' ∈ bs → β → m (ForInStep β)} @@ -122,6 +119,11 @@ theorem forIn'_loop_congr [Monad m] {as bs : List α} intros rfl +@[simp] theorem forIn_cons [Monad m] (f : α → β → m (ForInStep β)) (a : α) (as : List α) (b : β) : + forIn (a::as) b f = f a b >>= fun | ForInStep.done b => pure b | ForInStep.yield b => forIn as b f := by + have := forIn'_cons (a := a) (as := as) (fun a' _ b => f a' b) b + simpa only [forIn'_eq_forIn] + @[congr] theorem forIn'_congr [Monad m] {as bs : List α} (w : as = bs) {b b' : β} (hb : b = b') {f : (a' : α) → a' ∈ as → β → m (ForInStep β)} diff --git a/src/Init/Data/Option/Instances.lean b/src/Init/Data/Option/Instances.lean index cfce9a4038..580a7fb3ad 100644 --- a/src/Init/Data/Option/Instances.lean +++ b/src/Init/Data/Option/Instances.lean @@ -86,4 +86,6 @@ instance : ForIn' m (Option α) α inferInstance where match ← f a rfl init with | .done r | .yield r => return r +-- No separate `ForIn` instance is required because it can be derived from `ForIn'`. + end Option diff --git a/src/Init/Data/Range.lean b/src/Init/Data/Range.lean index 0de8c92228..55451cf1de 100644 --- a/src/Init/Data/Range.lean +++ b/src/Init/Data/Range.lean @@ -20,21 +20,6 @@ instance : Membership Nat Range where namespace Range universe u v -@[inline] protected def forIn {β : Type u} {m : Type u → Type v} [Monad m] (range : Range) (init : β) (f : Nat → β → m (ForInStep β)) : m β := - -- pass `stop` and `step` separately so the `range` object can be eliminated through inlining - let rec @[specialize] loop (fuel i stop step : Nat) (b : β) : m β := do - if i ≥ stop then - return b - else match fuel with - | 0 => pure b - | fuel+1 => match (← f i b) with - | ForInStep.done b => pure b - | ForInStep.yield b => loop fuel (i + step) stop step b - loop range.stop range.start range.stop range.step init - -instance : ForIn m Range Nat where - forIn := Range.forIn - @[inline] protected def forIn' {β : Type u} {m : Type u → Type v} [Monad m] (range : Range) (init : β) (f : (i : Nat) → i ∈ range → β → m (ForInStep β)) : m β := let rec @[specialize] loop (start stop step : Nat) (f : (i : Nat) → start ≤ i ∧ i < stop → β → m (ForInStep β)) (fuel i : Nat) (hl : start ≤ i) (b : β) : m β := do if hu : i < stop then @@ -50,6 +35,8 @@ instance : ForIn m Range Nat where instance : ForIn' m Range Nat inferInstance where forIn' := Range.forIn' +-- No separate `ForIn` instance is required because it can be derived from `ForIn'`. + @[inline] protected def forM {m : Type u → Type v} [Monad m] (range : Range) (f : Nat → m PUnit) : m PUnit := let rec @[specialize] loop (fuel i stop step : Nat) : m PUnit := do if i ≥ stop then diff --git a/src/Lean/Data/KVMap.lean b/src/Lean/Data/KVMap.lean index 60db90bbb7..f56a30e5fd 100644 --- a/src/Lean/Data/KVMap.lean +++ b/src/Lean/Data/KVMap.lean @@ -177,7 +177,7 @@ def updateSyntax (m : KVMap) (k : Name) (f : Syntax → Syntax) : KVMap := @[inline] protected def forIn.{w, w'} {δ : Type w} {m : Type w → Type w'} [Monad m] (kv : KVMap) (init : δ) (f : Name × DataValue → δ → m (ForInStep δ)) : m δ := - kv.entries.forIn init f + forIn kv.entries init f instance : ForIn m KVMap (Name × DataValue) where forIn := KVMap.forIn diff --git a/src/Std/Data/DHashMap/Raw.lean b/src/Std/Data/DHashMap/Raw.lean index 1c01c2a4cd..08bbc72035 100644 --- a/src/Std/Data/DHashMap/Raw.lean +++ b/src/Std/Data/DHashMap/Raw.lean @@ -334,7 +334,7 @@ map in some order. /-- Support for the `for` loop construct in `do` blocks. -/ @[inline] def forIn (f : (a : α) → β a → δ → m (ForInStep δ)) (init : δ) (b : Raw α β) : m δ := - b.buckets.forIn init (fun bucket acc => bucket.forInStep acc f) + ForIn.forIn b.buckets init (fun bucket acc => bucket.forInStep acc f) instance : ForM m (Raw α β) ((a : α) × β a) where forM m f := m.forM (fun a b => f ⟨a, b⟩) diff --git a/src/lake/Lake/Util/OrdHashSet.lean b/src/lake/Lake/Util/OrdHashSet.lean index fea3d04a59..003af356e1 100644 --- a/src/lake/Lake/Util/OrdHashSet.lean +++ b/src/lake/Lake/Util/OrdHashSet.lean @@ -69,6 +69,6 @@ def ofArray (arr : Array α) : OrdHashSet α := self.toArray.forM f @[inline] protected def forIn [Monad m] (self : OrdHashSet α) (init : β) (f : α → β → m (ForInStep β)) : m β := - self.toArray.forIn init f + ForIn.forIn self.toArray init f instance : ForIn m (OrdHashSet α) α := ⟨OrdHashSet.forIn⟩ diff --git a/src/lake/Lake/Util/RBArray.lean b/src/lake/Lake/Util/RBArray.lean index a0a18640dd..3794801f6c 100644 --- a/src/lake/Lake/Util/RBArray.lean +++ b/src/lake/Lake/Util/RBArray.lean @@ -63,7 +63,7 @@ def insert (self : RBArray α β cmp) (a : α) (b : β) : RBArray α β cmp := self.toArray.forM f @[inline] protected def forIn [Monad m] (self : RBArray α β cmp) (init : σ) (f : β → σ → m (ForInStep σ)) : m σ := - self.toArray.forIn init f + ForIn.forIn self.toArray init f instance : ForIn m (RBArray α β cmp) β := ⟨RBArray.forIn⟩ diff --git a/tests/bench/liasolver.lean b/tests/bench/liasolver.lean index 8a5633b8e2..adcfab670f 100644 --- a/tests/bench/liasolver.lean +++ b/tests/bench/liasolver.lean @@ -53,7 +53,7 @@ namespace Lean.HashMap @[inline] protected def forIn {δ : Type w} {m : Type w → Type w'} [Monad m] (as : HashMap α β) (init : δ) (f : (α × β) → δ → m (ForInStep δ)) : m δ := do - as.val.buckets.val.forIn init fun bucket acc => do + forIn as.val.buckets.val init fun bucket acc => do let (done, v) ← bucket.forIn (false, acc) fun v (_, acc) => do let r ← f v acc match r with diff --git a/tests/lean/run/array_simp.lean b/tests/lean/run/array_simp.lean index d0ed4dd48d..4e6ff093b1 100644 --- a/tests/lean/run/array_simp.lean +++ b/tests/lean/run/array_simp.lean @@ -10,7 +10,7 @@ attribute [local simp] Id.run in #check_simp (Id.run do let mut s := 0 - for i in [1,2,3,4].toArray do + for i in #[1,2,3,4] do s := s + i pure s) ~> 10 @@ -18,6 +18,6 @@ attribute [local simp] Id.run in #check_simp (Id.run do let mut s := 0 - for h : i in [1,2,3,4].toArray do + for h : i in #[1,2,3,4] do s := s + i pure s) ~> 10 diff --git a/tests/lean/run/list_simp.lean b/tests/lean/run/list_simp.lean index 6fd985face..ed4f89c716 100644 --- a/tests/lean/run/list_simp.lean +++ b/tests/lean/run/list_simp.lean @@ -461,6 +461,21 @@ end Pairwise /-! ### max? -/ /-! ## Monadic operations -/ +attribute [local simp] Id.run in +#check_simp + (Id.run do + let mut s := 0 + for i in [1,2,3,4] do + s := s + i + pure s) ~> 10 + +attribute [local simp] Id.run in +#check_simp + (Id.run do + let mut s := 0 + for h : i in [1,2,3,4] do + s := s + i + pure s) ~> 10 /-! ### mapM -/ diff --git a/tests/lean/run/treeNode.lean b/tests/lean/run/treeNode.lean index 961386d87b..7d333da8f6 100644 --- a/tests/lean/run/treeNode.lean +++ b/tests/lean/run/treeNode.lean @@ -13,13 +13,13 @@ def treeToList (t : TreeNode) : List String := r := r ++ treeToList child return r -@[simp] theorem treeToList_eq (name : String) (children : List TreeNode) : treeToList (.mkNode name children) = name :: List.flatten (children.map treeToList) := by - simp only [treeToList, Id.run, Id.pure_eq, Id.bind_eq, List.forIn'_eq_forIn, forIn, List.forIn] - have : ∀ acc, (Id.run do List.forIn.loop (fun a b => ForInStep.yield (b ++ treeToList a)) children acc) = acc ++ List.flatten (List.map treeToList children) := by - intro acc - induction children generalizing acc with simp [List.forIn.loop, List.map, List.flatten, Id.run] - | cons c cs ih => simp [Id.run] at ih; simp [ih, List.append_assoc] - apply this +@[simp] theorem treeToList_eq (name : String) (children : List TreeNode) : treeToList (.mkNode name children) = name :: List.flatten (children.map treeToList) := by + simp [treeToList, Id.run] + conv => rhs; rw [← List.singleton_append] + generalize [name] = as + induction children generalizing as with + | nil => simp + | cons c cs ih => simp [ih, List.append_assoc] mutual def numNames : TreeNode → Nat