feat: add further cbv annotations (#13135)

This PR adds several `cbv_opaque` and `cbv_eval` annotations to the
standard library.
This commit is contained in:
Wojciech Różowski 2026-03-26 14:55:40 +00:00 committed by GitHub
parent a8bbc95d9f
commit ccef9588ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 150 additions and 9 deletions

View file

@ -113,7 +113,7 @@ public theorem _root_.List.min?_toArray [Min α] {l : List α} :
· simp [List.min_toArray, List.min_eq_get_min?, - List.get_min?]
· simp_all
@[simp, grind =]
@[simp, grind =, cbv_eval ←]
public theorem min?_toList [Min α] {xs : Array α} :
xs.toList.min? = xs.min? := by
cases xs; simp
@ -153,7 +153,7 @@ public theorem _root_.List.max?_toArray [Max α] {l : List α} :
· simp [List.max_toArray, List.max_eq_get_max?, - List.get_max?]
· simp_all
@[simp, grind =]
@[simp, grind =, cbv_eval ←]
public theorem max?_toList [Max α] {xs : Array α} :
xs.toList.max? = xs.max? := by
cases xs; simp

View file

@ -66,7 +66,7 @@ lists are prepend-only, this `toListRev` is usually more efficient that `toList`
If the iterator is not finite, this function might run forever. The variant
`it.ensureTermination.toListRev` always terminates after finitely many steps.
-/
@[always_inline, inline]
@[always_inline, inline, cbv_opaque]
def Iter.toListRev {α : Type w} {β : Type w}
[Iterator α Id β] (it : Iter (α := α) β) : List β :=
it.toIterM.toListRev.run

View file

@ -226,7 +226,7 @@ any element emitted by the iterator {name}`it`.
{lit}`O(|xs|)`. Short-circuits upon encountering the first match. The elements in {name}`it` are
examined in order of iteration.
-/
@[inline]
@[inline, cbv_opaque]
def Iter.any {α β : Type w}
[Iterator α Id β] [IteratorLoop α Id Id]
(p : β → Bool) (it : Iter (α := α) β) : Bool :=
@ -292,7 +292,7 @@ all element emitted by the iterator {name}`it`.
{lit}`O(|xs|)`. Short-circuits upon encountering the first match. The elements in {name}`it` are
examined in order of iteration.
-/
@[inline]
@[inline, cbv_opaque]
def Iter.all {α β : Type w}
[Iterator α Id β] [IteratorLoop α Id Id]
(p : β → Bool) (it : Iter (α := α) β) : Bool :=
@ -644,7 +644,7 @@ Examples:
* `[7, 6].iter.first? = some 7`
* `[].iter.first? = none`
-/
@[inline]
@[inline, cbv_opaque]
def Iter.first? {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
(it : Iter (α := α) β) : Option β :=
it.toIterM.first?.run

View file

@ -110,6 +110,7 @@ theorem Iter.reverse_toListRev_ensureTermination [Iterator α Id β] [Finite α
it.ensureTermination.toListRev.reverse = it.toList := by
simp
@[cbv_eval]
theorem Iter.toListRev_eq {α β} [Iterator α Id β] [Finite α Id]
{it : Iter (α := α) β} :
it.toListRev = it.toList.reverse := by

View file

@ -637,6 +637,7 @@ theorem Iter.any_eq_forIn {α β : Type w} [Iterator α Id β]
return .yield false)).run := by
simp [any_eq_anyM, anyM_eq_forIn]
@[cbv_eval ←]
theorem Iter.any_toList {α β : Type w} [Iterator α Id β]
[Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} {p : β → Bool} :
@ -727,6 +728,7 @@ theorem Iter.all_eq_forIn {α β : Type w} [Iterator α Id β]
return .done false)).run := by
simp [all_eq_allM, allM_eq_forIn]
@[cbv_eval ←]
theorem Iter.all_toList {α β : Type w} [Iterator α Id β]
[Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} {p : β → Bool} :
@ -954,7 +956,7 @@ theorem Iter.first?_eq_match_step {α β : Type w} [Iterator α Id β] [Iterator
generalize it.toIterM.step.run.inflate = s
rcases s with ⟨_|_|_, _⟩ <;> simp [Iter.first?_eq_first?_toIterM]
@[simp, grind =]
@[simp, grind =, cbv_eval ←]
theorem Iter.head?_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
[Finite α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} :
it.toList.head? = it.first? := by

View file

@ -193,6 +193,7 @@ public theorem Array.toSubarray_eq_toSubarray_of_min_eq_min {xs : Array α}
simp [*]; omega
· simp
@[cbv_eval]
public theorem Array.toSubarray_eq_min {xs : Array α} {lo hi : Nat} :
xs.toSubarray lo hi = ⟨⟨xs, min lo (min hi xs.size), min hi xs.size, Nat.min_le_right _ _,
Nat.min_le_right _ _⟩⟩ := by

View file

@ -26,7 +26,7 @@ The iterator yields the elements of the map in order and then terminates.
* `Finite` instance: always
* `Productive` instance: always
-/
@[inline]
@[inline, cbv_opaque]
public def iter {α : Type u} {β : Type v}
{cmp : αα → Ordering} (m : TreeMap α β cmp) :=
(m.inner.iter.map fun e => (e.1, e.2) : Iter (α × β))
@ -63,7 +63,7 @@ public def valuesIter {α : Type u} {β : Type u} {cmp : αα → Ordering}
(m : TreeMap α β cmp) :=
m.inner.valuesIter
@[simp]
@[simp, cbv_eval]
public theorem toList_iter {cmp : αα → Ordering} (m : TreeMap α β cmp) :
m.iter.toList = m.toList := by
simp only [iter, Iter.toList_map, DTreeMap.toList_iter, DTreeMap.toList,

View file

@ -0,0 +1,15 @@
module
import Std
open Std
def isPalindrome (s : String) : Bool :=
s.chars.zip s.revChars |>.all (fun p => p.1 == p.2)
example : isPalindrome "" = true := by cbv
example : isPalindrome "aba" = true := by cbv
example : isPalindrome "aaaaa" = true := by cbv
example : isPalindrome "zbcd" = false := by cbv
example : isPalindrome "xywyx" = true := by cbv
example : isPalindrome "xywyz" = false := by cbv
example : isPalindrome "xywxz" = false := by cbv

View file

@ -0,0 +1,46 @@
module
public import Std
public import Init.Data.Iterators.Lemmas.Basic
open Std
public section
def frequencies (xs : List Nat) : TreeMap Nat Nat (fun a b => compare b a) :=
xs.foldl (init := ∅) (fun freq (x : Nat) => freq.alter x (fun v? => some (v?.getD 0 + 1)))
def search (xs : List Nat) : Int :=
let frequencies := frequencies xs
let kv := frequencies.iter
|>.filter (fun (k, v) => 0 < k ∧ k ≤ v)
|>.map (fun (k, _) => k)
|>.first?
kv.getD (-1)
/-! ## Tests -/
example : search [5, 5, 5, 5, 1] = 1 := by cbv
example : search [4, 1, 4, 1, 4, 4] = 4 := by cbv
example : search [3, 3] = -1 := by cbv
example : search [8, 8, 8, 8, 8, 8, 8, 8] = 8 := by cbv
example : search [2, 3, 3, 2, 2] = 2 := by cbv
example : search [2, 7, 8, 8, 4, 8, 7, 3, 9, 6, 5, 10, 4, 3, 6, 7, 1, 7, 4, 10, 8, 1] = 1 := by cbv
example : search [3, 2, 8, 2] = 2 := by cbv
example : search [6, 7, 1, 8, 8, 10, 5, 8, 5, 3, 10] = 1 := by cbv
example : search [8, 8, 3, 6, 5, 6, 4] = -1 := by cbv
example : search [6, 9, 6, 7, 1, 4, 7, 1, 8, 8, 9, 8, 10, 10, 8, 4, 10, 4, 10, 1, 2, 9, 5, 7, 9] = 1 := by cbv
example : search [1, 9, 10, 1, 3] = 1 := by cbv
example : search [6, 9, 7, 5, 8, 7, 5, 3, 7, 5, 10, 10, 3, 6, 10, 2, 8, 6, 5, 4, 9, 5, 3, 10] = 5 := by cbv
example : search [1] = 1 := by cbv
example : search [8, 8, 10, 6, 4, 3, 5, 8, 2, 4, 2, 8, 4, 6, 10, 4, 2, 1, 10, 2, 1, 1, 5] = 4 := by cbv
example : search [2, 10, 4, 8, 2, 10, 5, 1, 2, 9, 5, 5, 6, 3, 8, 6, 4, 10] = 2 := by cbv
example : search [1, 6, 10, 1, 6, 9, 10, 8, 6, 8, 7, 3] = 1 := by cbv
example : search [9, 2, 4, 1, 5, 1, 5, 2, 5, 7, 7, 7, 3, 10, 1, 5, 4, 2, 8, 4, 1, 9, 10, 7, 10, 2, 8, 10, 9, 4] = 4 := by cbv
example : search [2, 6, 4, 2, 8, 7, 5, 6, 4, 10, 4, 6, 3, 7, 8, 8, 3, 1, 4, 2, 2, 10, 7] = 4 := by cbv
example : search [9, 8, 6, 10, 2, 6, 10, 2, 7, 8, 10, 3, 8, 2, 6, 2, 3, 1] = 2 := by cbv
example : search [5, 5, 3, 9, 5, 6, 3, 2, 8, 5, 6, 10, 10, 6, 8, 4, 10, 7, 7, 10, 8] = -1 := by cbv
example : search [10] = -1 := by cbv
example : search [9, 7, 7, 2, 4, 7, 2, 10, 9, 7, 5, 7, 2] = 2 := by cbv
example : search [5, 4, 10, 2, 1, 1, 10, 3, 6, 1, 8] = 1 := by cbv
example : search [7, 9, 9, 9, 3, 4, 1, 5, 9, 1, 2, 1, 1, 10, 7, 5, 6, 7, 6, 7, 7, 6] = 1 := by cbv
example : search [3, 10, 10, 9, 2] = -1 := by cbv

View file

@ -0,0 +1,9 @@
module
public import Init.Data.Array.MinMax
def maxElement (xs : Array Int) : Int :=
xs.max?.getD 0
example : maxElement #[1, 2, 3] = 3 := by cbv
example : maxElement #[5, 3, -5, 2, -3, 3, 9, 0, 124, 1, -10] = 124 := by cbv

View file

@ -0,0 +1,29 @@
module
import Std.Data.Iterators
@[grind =]
def makeAPile₁ (n : Nat) : List Nat :=
(*...n).iter.map (n + 2 * ·) |>.toList
@[grind =]
def makeAPile₂ (n : Nat) : List Nat :=
(*...n).iter.map (fun i => n + 2 * (n - 1 - i)) |>.toListRev
example : makeAPile₁ 0 = [] := by cbv
example : makeAPile₁ 1 = [1] := by cbv
example : makeAPile₁ 2 = [2, 4] := by cbv
example : makeAPile₁ 3 = [3, 5, 7] := by cbv
example : makeAPile₁ 4 = [4, 6, 8, 10] := by cbv
example : makeAPile₁ 5 = [5, 7, 9, 11, 13] := by cbv
example : makeAPile₁ 6 = [6, 8, 10, 12, 14, 16] := by cbv
example : makeAPile₁ 8 = [8, 10, 12, 14, 16, 18, 20, 22] := by cbv
example : makeAPile₂ 0 = [] := by cbv
example : makeAPile₂ 1 = [1] := by cbv
example : makeAPile₂ 2 = [2, 4] := by cbv
example : makeAPile₂ 3 = [3, 5, 7] := by cbv
example : makeAPile₂ 4 = [4, 6, 8, 10] := by cbv
example : makeAPile₂ 5 = [5, 7, 9, 11, 13] := by cbv
example : makeAPile₂ 6 = [6, 8, 10, 12, 14, 16] := by cbv
example : makeAPile₂ 8 = [8, 10, 12, 14, 16, 18, 20, 22] := by cbv

View file

@ -0,0 +1,38 @@
module
import Std
open Std Std.Do
set_option mvcgen.warning false
def isSorted (xs : Array Nat) : Bool := Id.run do
if h : xs.size > 0 then
let mut last := xs[0]
let mut repeated := false
for x in xs[1...*] do
match compare last x with
| .lt =>
last := x
repeated := false
| .eq =>
if repeated then
return false
else
repeated := true
| .gt =>
return false
return true
example : isSorted #[5] = true := by cbv
example : isSorted #[1, 2, 3, 4, 5] = true := by cbv
example : isSorted #[1, 3, 2, 4, 5] = false := by cbv
example : isSorted #[1, 2, 3, 4, 5, 6] = true := by cbv
example : isSorted #[1, 2, 3, 4, 5, 6, 7] = true := by cbv
example : isSorted #[1, 3, 2, 4, 5, 6, 7] = false := by cbv
example : isSorted #[] = true := by cbv
example : isSorted #[1] = true := by cbv
example : isSorted #[3, 2, 1] = false := by cbv
example : isSorted #[1, 2, 2, 2, 3, 4] = false := by cbv
example : isSorted #[1, 2, 3, 3, 3, 4] = false := by cbv
example : isSorted #[1, 2, 2, 3, 3, 4] = true := by cbv
example : isSorted #[1, 2, 3, 4] = true := by cbv