From ccef9588ae95ca09b92d20cfedd68417f2bd48d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20R=C3=B3=C5=BCowski?= Date: Thu, 26 Mar 2026 14:55:40 +0000 Subject: [PATCH] feat: add further `cbv` annotations (#13135) This PR adds several `cbv_opaque` and `cbv_eval` annotations to the standard library. --- src/Init/Data/Array/MinMax.lean | 4 +- .../Data/Iterators/Consumers/Collect.lean | 2 +- src/Init/Data/Iterators/Consumers/Loop.lean | 6 +-- .../Iterators/Lemmas/Consumers/Collect.lean | 1 + .../Data/Iterators/Lemmas/Consumers/Loop.lean | 4 +- src/Init/Data/Slice/Array/Lemmas.lean | 1 + src/Std/Data/TreeMap/Iterator.lean | 4 +- tests/elab/cbv_annotations2.lean | 15 ++++++ tests/elab/cbv_annotations3.lean | 46 +++++++++++++++++++ tests/elab/cbv_annotations4.lean | 9 ++++ tests/elab/cbv_annotations5.lean | 29 ++++++++++++ tests/elab/cbv_annotations6.lean | 38 +++++++++++++++ 12 files changed, 150 insertions(+), 9 deletions(-) create mode 100644 tests/elab/cbv_annotations2.lean create mode 100644 tests/elab/cbv_annotations3.lean create mode 100644 tests/elab/cbv_annotations4.lean create mode 100644 tests/elab/cbv_annotations5.lean create mode 100644 tests/elab/cbv_annotations6.lean diff --git a/src/Init/Data/Array/MinMax.lean b/src/Init/Data/Array/MinMax.lean index 6bb8a60593..297bbf7096 100644 --- a/src/Init/Data/Array/MinMax.lean +++ b/src/Init/Data/Array/MinMax.lean @@ -113,7 +113,7 @@ public theorem _root_.List.min?_toArray [Min α] {l : List α} : · simp [List.min_toArray, List.min_eq_get_min?, - List.get_min?] · simp_all -@[simp, grind =] +@[simp, grind =, cbv_eval ←] public theorem min?_toList [Min α] {xs : Array α} : xs.toList.min? = xs.min? := by cases xs; simp @@ -153,7 +153,7 @@ public theorem _root_.List.max?_toArray [Max α] {l : List α} : · simp [List.max_toArray, List.max_eq_get_max?, - List.get_max?] · simp_all -@[simp, grind =] +@[simp, grind =, cbv_eval ←] public theorem max?_toList [Max α] {xs : Array α} : xs.toList.max? = xs.max? := by cases xs; simp diff --git a/src/Init/Data/Iterators/Consumers/Collect.lean b/src/Init/Data/Iterators/Consumers/Collect.lean index cecd907a78..329917d62f 100644 --- a/src/Init/Data/Iterators/Consumers/Collect.lean +++ b/src/Init/Data/Iterators/Consumers/Collect.lean @@ -66,7 +66,7 @@ lists are prepend-only, this `toListRev` is usually more efficient that `toList` If the iterator is not finite, this function might run forever. The variant `it.ensureTermination.toListRev` always terminates after finitely many steps. -/ -@[always_inline, inline] +@[always_inline, inline, cbv_opaque] def Iter.toListRev {α : Type w} {β : Type w} [Iterator α Id β] (it : Iter (α := α) β) : List β := it.toIterM.toListRev.run diff --git a/src/Init/Data/Iterators/Consumers/Loop.lean b/src/Init/Data/Iterators/Consumers/Loop.lean index 46a6b6b950..def18caf45 100644 --- a/src/Init/Data/Iterators/Consumers/Loop.lean +++ b/src/Init/Data/Iterators/Consumers/Loop.lean @@ -226,7 +226,7 @@ any element emitted by the iterator {name}`it`. {lit}`O(|xs|)`. Short-circuits upon encountering the first match. The elements in {name}`it` are examined in order of iteration. -/ -@[inline] +@[inline, cbv_opaque] def Iter.any {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id] (p : β → Bool) (it : Iter (α := α) β) : Bool := @@ -292,7 +292,7 @@ all element emitted by the iterator {name}`it`. {lit}`O(|xs|)`. Short-circuits upon encountering the first match. The elements in {name}`it` are examined in order of iteration. -/ -@[inline] +@[inline, cbv_opaque] def Iter.all {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id] (p : β → Bool) (it : Iter (α := α) β) : Bool := @@ -644,7 +644,7 @@ Examples: * `[7, 6].iter.first? = some 7` * `[].iter.first? = none` -/ -@[inline] +@[inline, cbv_opaque] def Iter.first? {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id] (it : Iter (α := α) β) : Option β := it.toIterM.first?.run diff --git a/src/Init/Data/Iterators/Lemmas/Consumers/Collect.lean b/src/Init/Data/Iterators/Lemmas/Consumers/Collect.lean index a5ac30aa8c..59cdf72c3d 100644 --- a/src/Init/Data/Iterators/Lemmas/Consumers/Collect.lean +++ b/src/Init/Data/Iterators/Lemmas/Consumers/Collect.lean @@ -110,6 +110,7 @@ theorem Iter.reverse_toListRev_ensureTermination [Iterator α Id β] [Finite α it.ensureTermination.toListRev.reverse = it.toList := by simp +@[cbv_eval] theorem Iter.toListRev_eq {α β} [Iterator α Id β] [Finite α Id] {it : Iter (α := α) β} : it.toListRev = it.toList.reverse := by diff --git a/src/Init/Data/Iterators/Lemmas/Consumers/Loop.lean b/src/Init/Data/Iterators/Lemmas/Consumers/Loop.lean index 43fe66f310..4a501b5d86 100644 --- a/src/Init/Data/Iterators/Lemmas/Consumers/Loop.lean +++ b/src/Init/Data/Iterators/Lemmas/Consumers/Loop.lean @@ -637,6 +637,7 @@ theorem Iter.any_eq_forIn {α β : Type w} [Iterator α Id β] return .yield false)).run := by simp [any_eq_anyM, anyM_eq_forIn] +@[cbv_eval ←] theorem Iter.any_toList {α β : Type w} [Iterator α Id β] [Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} {p : β → Bool} : @@ -727,6 +728,7 @@ theorem Iter.all_eq_forIn {α β : Type w} [Iterator α Id β] return .done false)).run := by simp [all_eq_allM, allM_eq_forIn] +@[cbv_eval ←] theorem Iter.all_toList {α β : Type w} [Iterator α Id β] [Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} {p : β → Bool} : @@ -954,7 +956,7 @@ theorem Iter.first?_eq_match_step {α β : Type w} [Iterator α Id β] [Iterator generalize it.toIterM.step.run.inflate = s rcases s with ⟨_|_|_, _⟩ <;> simp [Iter.first?_eq_first?_toIterM] -@[simp, grind =] +@[simp, grind =, cbv_eval ←] theorem Iter.head?_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id] [Finite α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} : it.toList.head? = it.first? := by diff --git a/src/Init/Data/Slice/Array/Lemmas.lean b/src/Init/Data/Slice/Array/Lemmas.lean index 5cc4c2174c..0aee046bd4 100644 --- a/src/Init/Data/Slice/Array/Lemmas.lean +++ b/src/Init/Data/Slice/Array/Lemmas.lean @@ -193,6 +193,7 @@ public theorem Array.toSubarray_eq_toSubarray_of_min_eq_min {xs : Array α} simp [*]; omega · simp +@[cbv_eval] public theorem Array.toSubarray_eq_min {xs : Array α} {lo hi : Nat} : xs.toSubarray lo hi = ⟨⟨xs, min lo (min hi xs.size), min hi xs.size, Nat.min_le_right _ _, Nat.min_le_right _ _⟩⟩ := by diff --git a/src/Std/Data/TreeMap/Iterator.lean b/src/Std/Data/TreeMap/Iterator.lean index 9526c2e5c4..51cb16f60c 100644 --- a/src/Std/Data/TreeMap/Iterator.lean +++ b/src/Std/Data/TreeMap/Iterator.lean @@ -26,7 +26,7 @@ The iterator yields the elements of the map in order and then terminates. * `Finite` instance: always * `Productive` instance: always -/ -@[inline] +@[inline, cbv_opaque] public def iter {α : Type u} {β : Type v} {cmp : α → α → Ordering} (m : TreeMap α β cmp) := (m.inner.iter.map fun e => (e.1, e.2) : Iter (α × β)) @@ -63,7 +63,7 @@ public def valuesIter {α : Type u} {β : Type u} {cmp : α → α → Ordering} (m : TreeMap α β cmp) := m.inner.valuesIter -@[simp] +@[simp, cbv_eval] public theorem toList_iter {cmp : α → α → Ordering} (m : TreeMap α β cmp) : m.iter.toList = m.toList := by simp only [iter, Iter.toList_map, DTreeMap.toList_iter, DTreeMap.toList, diff --git a/tests/elab/cbv_annotations2.lean b/tests/elab/cbv_annotations2.lean new file mode 100644 index 0000000000..6d03e40742 --- /dev/null +++ b/tests/elab/cbv_annotations2.lean @@ -0,0 +1,15 @@ +module +import Std + +open Std + +def isPalindrome (s : String) : Bool := + s.chars.zip s.revChars |>.all (fun p => p.1 == p.2) + +example : isPalindrome "" = true := by cbv +example : isPalindrome "aba" = true := by cbv +example : isPalindrome "aaaaa" = true := by cbv +example : isPalindrome "zbcd" = false := by cbv +example : isPalindrome "xywyx" = true := by cbv +example : isPalindrome "xywyz" = false := by cbv +example : isPalindrome "xywxz" = false := by cbv diff --git a/tests/elab/cbv_annotations3.lean b/tests/elab/cbv_annotations3.lean new file mode 100644 index 0000000000..5413571f3b --- /dev/null +++ b/tests/elab/cbv_annotations3.lean @@ -0,0 +1,46 @@ +module + +public import Std +public import Init.Data.Iterators.Lemmas.Basic +open Std + +public section + +def frequencies (xs : List Nat) : TreeMap Nat Nat (fun a b => compare b a) := + xs.foldl (init := ∅) (fun freq (x : Nat) => freq.alter x (fun v? => some (v?.getD 0 + 1))) + +def search (xs : List Nat) : Int := + let frequencies := frequencies xs + let kv := frequencies.iter + |>.filter (fun (k, v) => 0 < k ∧ k ≤ v) + |>.map (fun (k, _) => k) + |>.first? + kv.getD (-1) + +/-! ## Tests -/ + +example : search [5, 5, 5, 5, 1] = 1 := by cbv +example : search [4, 1, 4, 1, 4, 4] = 4 := by cbv +example : search [3, 3] = -1 := by cbv +example : search [8, 8, 8, 8, 8, 8, 8, 8] = 8 := by cbv +example : search [2, 3, 3, 2, 2] = 2 := by cbv +example : search [2, 7, 8, 8, 4, 8, 7, 3, 9, 6, 5, 10, 4, 3, 6, 7, 1, 7, 4, 10, 8, 1] = 1 := by cbv +example : search [3, 2, 8, 2] = 2 := by cbv +example : search [6, 7, 1, 8, 8, 10, 5, 8, 5, 3, 10] = 1 := by cbv +example : search [8, 8, 3, 6, 5, 6, 4] = -1 := by cbv +example : search [6, 9, 6, 7, 1, 4, 7, 1, 8, 8, 9, 8, 10, 10, 8, 4, 10, 4, 10, 1, 2, 9, 5, 7, 9] = 1 := by cbv +example : search [1, 9, 10, 1, 3] = 1 := by cbv +example : search [6, 9, 7, 5, 8, 7, 5, 3, 7, 5, 10, 10, 3, 6, 10, 2, 8, 6, 5, 4, 9, 5, 3, 10] = 5 := by cbv +example : search [1] = 1 := by cbv +example : search [8, 8, 10, 6, 4, 3, 5, 8, 2, 4, 2, 8, 4, 6, 10, 4, 2, 1, 10, 2, 1, 1, 5] = 4 := by cbv +example : search [2, 10, 4, 8, 2, 10, 5, 1, 2, 9, 5, 5, 6, 3, 8, 6, 4, 10] = 2 := by cbv +example : search [1, 6, 10, 1, 6, 9, 10, 8, 6, 8, 7, 3] = 1 := by cbv +example : search [9, 2, 4, 1, 5, 1, 5, 2, 5, 7, 7, 7, 3, 10, 1, 5, 4, 2, 8, 4, 1, 9, 10, 7, 10, 2, 8, 10, 9, 4] = 4 := by cbv +example : search [2, 6, 4, 2, 8, 7, 5, 6, 4, 10, 4, 6, 3, 7, 8, 8, 3, 1, 4, 2, 2, 10, 7] = 4 := by cbv +example : search [9, 8, 6, 10, 2, 6, 10, 2, 7, 8, 10, 3, 8, 2, 6, 2, 3, 1] = 2 := by cbv +example : search [5, 5, 3, 9, 5, 6, 3, 2, 8, 5, 6, 10, 10, 6, 8, 4, 10, 7, 7, 10, 8] = -1 := by cbv +example : search [10] = -1 := by cbv +example : search [9, 7, 7, 2, 4, 7, 2, 10, 9, 7, 5, 7, 2] = 2 := by cbv +example : search [5, 4, 10, 2, 1, 1, 10, 3, 6, 1, 8] = 1 := by cbv +example : search [7, 9, 9, 9, 3, 4, 1, 5, 9, 1, 2, 1, 1, 10, 7, 5, 6, 7, 6, 7, 7, 6] = 1 := by cbv +example : search [3, 10, 10, 9, 2] = -1 := by cbv diff --git a/tests/elab/cbv_annotations4.lean b/tests/elab/cbv_annotations4.lean new file mode 100644 index 0000000000..cd97eddfb3 --- /dev/null +++ b/tests/elab/cbv_annotations4.lean @@ -0,0 +1,9 @@ +module + +public import Init.Data.Array.MinMax + +def maxElement (xs : Array Int) : Int := + xs.max?.getD 0 + +example : maxElement #[1, 2, 3] = 3 := by cbv +example : maxElement #[5, 3, -5, 2, -3, 3, 9, 0, 124, 1, -10] = 124 := by cbv diff --git a/tests/elab/cbv_annotations5.lean b/tests/elab/cbv_annotations5.lean new file mode 100644 index 0000000000..1a656fa39d --- /dev/null +++ b/tests/elab/cbv_annotations5.lean @@ -0,0 +1,29 @@ +module + +import Std.Data.Iterators + +@[grind =] +def makeAPile₁ (n : Nat) : List Nat := + (*...n).iter.map (n + 2 * ·) |>.toList + +@[grind =] +def makeAPile₂ (n : Nat) : List Nat := + (*...n).iter.map (fun i => n + 2 * (n - 1 - i)) |>.toListRev + +example : makeAPile₁ 0 = [] := by cbv +example : makeAPile₁ 1 = [1] := by cbv +example : makeAPile₁ 2 = [2, 4] := by cbv +example : makeAPile₁ 3 = [3, 5, 7] := by cbv +example : makeAPile₁ 4 = [4, 6, 8, 10] := by cbv +example : makeAPile₁ 5 = [5, 7, 9, 11, 13] := by cbv +example : makeAPile₁ 6 = [6, 8, 10, 12, 14, 16] := by cbv +example : makeAPile₁ 8 = [8, 10, 12, 14, 16, 18, 20, 22] := by cbv + +example : makeAPile₂ 0 = [] := by cbv +example : makeAPile₂ 1 = [1] := by cbv +example : makeAPile₂ 2 = [2, 4] := by cbv +example : makeAPile₂ 3 = [3, 5, 7] := by cbv +example : makeAPile₂ 4 = [4, 6, 8, 10] := by cbv +example : makeAPile₂ 5 = [5, 7, 9, 11, 13] := by cbv +example : makeAPile₂ 6 = [6, 8, 10, 12, 14, 16] := by cbv +example : makeAPile₂ 8 = [8, 10, 12, 14, 16, 18, 20, 22] := by cbv diff --git a/tests/elab/cbv_annotations6.lean b/tests/elab/cbv_annotations6.lean new file mode 100644 index 0000000000..e5b61c15ce --- /dev/null +++ b/tests/elab/cbv_annotations6.lean @@ -0,0 +1,38 @@ +module + +import Std +open Std Std.Do + +set_option mvcgen.warning false + +def isSorted (xs : Array Nat) : Bool := Id.run do + if h : xs.size > 0 then + let mut last := xs[0] + let mut repeated := false + for x in xs[1...*] do + match compare last x with + | .lt => + last := x + repeated := false + | .eq => + if repeated then + return false + else + repeated := true + | .gt => + return false + return true + +example : isSorted #[5] = true := by cbv +example : isSorted #[1, 2, 3, 4, 5] = true := by cbv +example : isSorted #[1, 3, 2, 4, 5] = false := by cbv +example : isSorted #[1, 2, 3, 4, 5, 6] = true := by cbv +example : isSorted #[1, 2, 3, 4, 5, 6, 7] = true := by cbv +example : isSorted #[1, 3, 2, 4, 5, 6, 7] = false := by cbv +example : isSorted #[] = true := by cbv +example : isSorted #[1] = true := by cbv +example : isSorted #[3, 2, 1] = false := by cbv +example : isSorted #[1, 2, 2, 2, 3, 4] = false := by cbv +example : isSorted #[1, 2, 3, 3, 3, 4] = false := by cbv +example : isSorted #[1, 2, 2, 3, 3, 4] = true := by cbv +example : isSorted #[1, 2, 3, 4] = true := by cbv