feat: well-founded recursion: opaque well-foundedness proofs (#5182)

This PR makes functions defined by well-founded recursion use an
`opaque` well-founded proof by default. This reliably prevents kernel
reduction of such definitions and proofs, which tends to be
prohibitively slow (fixes #2171), and which regularly causes
hard-to-debug kernel type-checking failures. This changes renders
`unseal` ineffective for such definitions. To avoid the opaque proof,
annotate the function definition with `@[semireducible]`.
This commit is contained in:
Joachim Breitner 2025-03-19 10:21:04 +01:00 committed by GitHub
parent bf241f9e86
commit 41a2e9af19
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 109 additions and 74 deletions

View file

@ -2280,6 +2280,12 @@ So, you are mainly losing the capability of type checking your development using
-/
axiom ofReduceNat (a b : Nat) (h : reduceNat a = b) : a = b
/--
The term `opaqueId x` will not be reduced by the kernel.
-/
opaque opaqueId {α : Sort u} (x : α) : α := x
end Lean
@[simp] theorem ge_iff_le [LE α] {x y : α} : x ≥ y ↔ y ≤ x := Iff.rfl

View file

@ -570,7 +570,9 @@ state, the right approach is usually the tactic `simp [Array.unattach, -Array.ma
-/
def unattach {α : Type _} {p : α → Prop} (xs : Array { x // p x }) : Array α := xs.map (·.val)
@[simp] theorem unattach_nil {p : α → Prop} : (#[] : Array { x // p x }).unattach = #[] := rfl
@[simp] theorem unattach_nil {p : α → Prop} : (#[] : Array { x // p x }).unattach = #[] := by
simp [unattach]
@[simp] theorem unattach_push {p : α → Prop} {a : { x // p x }} {xs : Array { x // p x }} :
(xs.push a).unattach = xs.unattach.push a.1 := by
simp only [unattach, Array.map_push]

View file

@ -21,7 +21,7 @@ open Nat
/-! ### eraseP -/
@[simp] theorem eraseP_empty : #[].eraseP p = #[] := rfl
@[simp] theorem eraseP_empty : #[].eraseP p = #[] := by simp
theorem eraseP_of_forall_mem_not {xs : Array α} (h : ∀ a, a ∈ xs → ¬p a) : xs.eraseP p = xs := by
rcases xs with ⟨xs⟩

View file

@ -408,7 +408,7 @@ theorem false_of_mem_extract_findIdx {xs : Array α} {p : α → Bool} (h : x
/-! ### findIdx? -/
@[simp] theorem findIdx?_empty : (#[] : Array α).findIdx? p = none := rfl
@[simp] theorem findIdx?_empty : (#[] : Array α).findIdx? p = none := by simp
@[simp]
theorem findIdx?_eq_none_iff {xs : Array α} {p : α → Bool} :
@ -526,7 +526,7 @@ theorem findIdx?_eq_some_le_of_findIdx?_eq_some {xs : Array α} {p q : α → Bo
/-! ### findFinIdx? -/
@[simp] theorem findFinIdx?_empty {p : α → Bool} : findFinIdx? p #[] = none := rfl
@[simp] theorem findFinIdx?_empty {p : α → Bool} : findFinIdx? p #[] = none := by simp
-- We can't mark this as a `@[congr]` lemma since the head of the RHS is not `findFinIdx?`.
theorem findFinIdx?_congr {p : α → Bool} {xs ys : Array α} (w : xs = ys) :
@ -595,7 +595,7 @@ The verification API for `idxOf?` is still incomplete.
The lemmas below should be made consistent with those for `findIdx?` (and proved using them).
-/
@[simp] theorem idxOf?_empty [BEq α] : (#[] : Array α).idxOf? a = none := rfl
@[simp] theorem idxOf?_empty [BEq α] : (#[] : Array α).idxOf? a = none := by simp
@[simp] theorem idxOf?_eq_none_iff [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
xs.idxOf? a = none ↔ a ∉ xs := by
@ -612,7 +612,7 @@ theorem idxOf?_eq_map_finIdxOf?_val [BEq α] {xs : Array α} {a : α} :
xs.idxOf? a = (xs.finIdxOf? a).map (·.val) := by
simp [idxOf?, finIdxOf?, findIdx?_eq_map_findFinIdx?_val]
@[simp] theorem finIdxOf?_empty [BEq α] : (#[] : Array α).finIdxOf? a = none := rfl
@[simp] theorem finIdxOf?_empty [BEq α] : (#[] : Array α).finIdxOf? a = none := by simp
@[simp] theorem finIdxOf?_eq_none_iff [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
xs.finIdxOf? a = none ↔ a ∉ xs := by

View file

@ -554,7 +554,7 @@ theorem anyM_loop_cons [Monad m] (p : α → m Bool) (a : α) (as : List α) (st
@[simp] theorem anyM_toList [Monad m] (p : α → m Bool) (as : Array α) :
as.toList.anyM p = as.anyM p :=
match as with
| ⟨[]⟩ => rfl
| ⟨[]⟩ => by simp [anyM, anyM.loop]
| ⟨a :: as⟩ => by
simp only [List.anyM, anyM, List.size_toArray, List.length_cons, Nat.le_refl, ↓reduceDIte]
rw [anyM.loop, dif_pos (by omega)]
@ -1178,7 +1178,7 @@ theorem map_id' (xs : Array α) : map (fun (a : α) => a) xs = xs := map_id xs
theorem map_id'' {f : αα} (h : ∀ x, f x = x) (xs : Array α) : map f xs = xs := by
simp [show f = id from funext h]
theorem map_singleton (f : α → β) (a : α) : map f #[a] = #[f a] := rfl
theorem map_singleton (f : α → β) (a : α) : map f #[a] = #[f a] := by simp
-- We use a lower priority here as there are more specific lemmas in downstream libraries
-- which should be able to fire first.

View file

@ -16,7 +16,8 @@ set_option linter.indexVariables true -- Enforce naming conventions for index va
namespace Array
@[simp] theorem ofFn_zero (f : Fin 0 → α) : ofFn f = #[] := rfl
@[simp] theorem ofFn_zero (f : Fin 0 → α) : ofFn f = #[] := by
simp [ofFn, ofFn.go]
theorem ofFn_succ (f : Fin (n+1) → α) :
ofFn f = (ofFn (fun (i : Fin n) => f i.castSucc)).push (f ⟨n, by omega⟩) := by

View file

@ -39,7 +39,8 @@ theorem range'_ne_empty_iff (s : Nat) {n step : Nat} : range' s n step ≠ #[]
@[simp] theorem range'_zero : range' s 0 step = #[] := by
simp
@[simp] theorem range'_one {s step : Nat} : range' s 1 step = #[s] := rfl
@[simp] theorem range'_one {s step : Nat} : range' s 1 step = #[s] := by
simp [range', ofFn, ofFn.go]
@[simp] theorem range'_inj : range' s n = range' s' n' ↔ n = n' ∧ (n = 0 s = s') := by
rw [← toList_inj]
@ -77,7 +78,7 @@ theorem range'_append (s m n step : Nat) :
range' s m ++ range' (s + m) n = range' s (m + n) := by simpa using range'_append s m n 1
theorem range'_concat (s n : Nat) : range' s (n + 1) step = range' s n step ++ #[s + step * n] := by
exact (range'_append s n 1 step).symm
simpa using (range'_append s n 1 step).symm
theorem range'_1_concat (s n : Nat) : range' s (n + 1) = range' s n ++ #[s + n] := by
simp [range'_concat]

View file

@ -141,7 +141,9 @@ theorem foldrM_loop [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x
| zero =>
rw [foldrM_loop_zero, foldrM_loop_succ, pure_bind]
conv => rhs; rw [←bind_pure (f 0 x)]
congr; funext
congr
funext
try simp only [foldrM.loop] -- the try makes this proof work with and without opaque wf rec
| succ i ih =>
rw [foldrM_loop_succ, foldrM_loop_succ, bind_assoc]
congr; funext; exact ih ..

View file

@ -468,7 +468,8 @@ If not, usually the right approach is `simp [Vector.unattach, -Vector.map_subtyp
-/
def unattach {α : Type _} {p : α → Prop} (xs : Vector { x // p x } n) : Vector α n := xs.map (·.val)
@[simp] theorem unattach_nil {p : α → Prop} : (#v[] : Vector { x // p x } 0).unattach = #v[] := rfl
@[simp] theorem unattach_nil {p : α → Prop} : (#v[] : Vector { x // p x } 0).unattach = #v[] := by simp
@[simp] theorem unattach_push {p : α → Prop} {a : { x // p x }} {xs : Vector { x // p x } n} :
(xs.push a).unattach = xs.unattach.push a.1 := by
simp only [unattach, Vector.map_push]

View file

@ -254,7 +254,7 @@ theorem find?_eq_some_iff_getElem {xs : Vector α n} {p : α → Bool} {b : α}
/-! ### findFinIdx? -/
@[simp] theorem findFinIdx?_empty {p : α → Bool} : findFinIdx? p (#v[] : Vector α 0) = none := rfl
@[simp] theorem findFinIdx?_empty {p : α → Bool} : findFinIdx? p (#v[] : Vector α 0) = none := by simp
@[congr] theorem findFinIdx?_congr {p : α → Bool} {xs : Vector α n} {ys : Vector α n} (w : xs = ys) :
findFinIdx? p xs = findFinIdx? p ys := by

View file

@ -1407,7 +1407,7 @@ theorem map_id' (xs : Vector α n) : map (fun (a : α) => a) xs = xs := map_id x
theorem map_id'' {f : αα} (h : ∀ x, f x = x) (xs : Vector α n) : map f xs = xs := by
simp [show f = id from funext h]
theorem map_singleton (f : α → β) (a : α) : map f #v[a] = #v[f a] := rfl
theorem map_singleton (f : α → β) (a : α) : map f #v[a] = #v[f a] := by simp
-- We use a lower priority here as there are more specific lemmas in downstream libraries
-- which should be able to fire first.

View file

@ -49,7 +49,7 @@ theorem range'_succ (s n step) :
theorem range'_zero : range' s 0 step = #v[] := by
simp
@[simp] theorem range'_one {s step : Nat} : range' s 1 step = #v[s] := rfl
@[simp] theorem range'_one {s step : Nat} : range' s 1 step = #v[s] := by simp
@[simp] theorem range'_inj : range' s n = range' s' n ↔ (n = 0 s = s') := by
rw [← toArray_inj]
@ -76,7 +76,7 @@ theorem range'_append (s m n step : Nat) :
range' s m ++ range' (s + m) n = range' s (m + n) := by simpa using range'_append s m n 1
theorem range'_concat (s n : Nat) : range' s (n + 1) step = range' s n step ++ #v[s + step * n] := by
exact (range'_append s n 1 step).symm
simpa using (range'_append s n 1 step).symm
theorem range'_1_concat (s n : Nat) : range' s (n + 1) = range' s n ++ #v[s + n] := by
simp [range'_concat]

View file

@ -232,7 +232,8 @@ def solveDecreasingGoals (funNames : Array Name) (argsPacker : ArgsPacker) (decr
instantiateMVars value
def mkFix (preDef : PreDefinition) (prefixArgs : Array Expr) (argsPacker : ArgsPacker)
(wfRel : Expr) (funNames : Array Name) (decrTactics : Array (Option DecreasingBy)) : TermElabM Expr := do
(wfRel : Expr) (funNames : Array Name) (decrTactics : Array (Option DecreasingBy))
(opaqueProof : Bool) : TermElabM Expr := do
let type ← instantiateForall preDef.type prefixArgs
let (wfFix, varName) ← forallBoundedTelescope type (some 1) fun x type => do
let x := x[0]!
@ -242,6 +243,7 @@ def mkFix (preDef : PreDefinition) (prefixArgs : Array Expr) (argsPacker : ArgsP
let motive ← mkLambdaFVars #[x] type
let rel := mkProj ``WellFoundedRelation 0 wfRel
let wf := mkProj ``WellFoundedRelation 1 wfRel
let wf ← if opaqueProof then mkAppM `Lean.opaqueId #[wf] else pure wf
let varName ← x.fvarId!.getUserName -- See comment below.
return (mkApp4 (mkConst ``WellFounded.fix [u, v]) α motive rel wf, varName)
forallBoundedTelescope (← whnf (← inferType wfFix)).bindingDomain! (some 2) fun xs _ => do

View file

@ -44,6 +44,11 @@ def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option T
-- No termination_by here, so use GuessLex to infer one
guessLex preDefs unaryPreDef fixedParamPerms argsPacker
let opaqueProof := !
preDefs.any fun preDef =>
preDef.modifiers.attrs.any fun a =>
a.name = `reducible || a.name = `semireducible
let preDefNonRec ← forallBoundedTelescope unaryPreDef.type fixedParamPerms.numFixed fun fixedArgs type => do
let type ← whnfForall type
unless type.isForall do
@ -53,7 +58,7 @@ def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option T
trace[Elab.definition.wf] "wfRel: {wfRel}"
let (value, envNew) ← withoutModifyingEnv' do
addAsAxiom unaryPreDef
let value ← mkFix unaryPreDef fixedArgs argsPacker wfRel (preDefs.map (·.declName)) (preDefs.map (·.termination.decreasingBy?))
let value ← mkFix unaryPreDef fixedArgs argsPacker wfRel (preDefs.map (·.declName)) (preDefs.map (·.termination.decreasingBy?)) opaqueProof
eraseRecAppSyntaxExpr value
/- `mkFix` invokes `decreasing_tactic` which may add auxiliary theorems to the environment. -/
let value ← unfoldDeclsFrom envNew value

View file

@ -1,5 +1,7 @@
#include "util/options.h"
// please update stage0
namespace lean {
options get_default_options() {
options opts;

View file

@ -28,12 +28,11 @@ def onlyZeros : Tree → Prop
| .node [] => True
| .node (x::s) => onlyZeros x ∧ onlyZeros (.node s)
unseal onlyZeros in
/-- Pattern-matching on `OnlyZeros` works despite `below` and `brecOn` not being generated
if we make `onlyZeros` semireducible-/
def toFixPoint : OnlyZeros t → onlyZeros t
| .leaf => rfl
| .node [] _ => True.intro
| .leaf => by simp [onlyZeros]
| .node [] _ => by simp [onlyZeros]
| .node (x::s) (.cons h p) => by
rw [onlyZeros] -- necessary because `onlyZeros` isn't structurally recursive
exact And.intro (toFixPoint h) (toFixPoint (.node s p))

View file

@ -1,4 +1,4 @@
-- NB: well-founded recursion, so irreducible
@[semireducible]
def sorted_from_var [x: LE α] [DecidableRel x.le] (a: Array α) (i: Nat): Bool :=
if h: i + 1 < a.size then
have : i < a.size := Nat.lt_of_succ_lt h
@ -7,6 +7,8 @@ def sorted_from_var [x: LE α] [DecidableRel x.le] (a: Array α) (i: Nat): Bool
true
termination_by a.size - i
attribute [irreducible] sorted_from_var
def check_sorted [x: LE α] [DecidableRel x.le] (a: Array α): Bool :=
sorted_from_var a 0

View file

@ -6,37 +6,28 @@ termination_by a b => (a, b)
/--
info: [diag] Diagnostics
[kernel] unfolded declarations (max: 1193, num: 5):
[kernel] Nat.casesOn ↦ 1193
[kernel] Nat.rec ↦ 1065
[kernel] Eq.ndrec ↦ 973
[kernel] Eq.rec ↦ 973
[kernel] Acc.rec ↦ 754
[kernel] unfolded declarations (max: 147, num: 3):
[kernel] OfNat.ofNat ↦ 147
[kernel] Add.add ↦ 61
[kernel] HAdd.hAdd ↦ 61
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
info: [simp] Diagnostics
[simp] used theorems (max: 59, num: 1):
[simp] ack.eq_3 ↦ 59
[simp] tried theorems (max: 59, num: 1):
[simp] ack.eq_3 ↦ 59, succeeded: 59
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
info: [diag] Diagnostics
[reduction] unfolded declarations (max: 2567, num: 5):
[reduction] Nat.rec ↦ 2567
[reduction] Eq.rec ↦ 1517
[reduction] Acc.rec ↦ 1158
[reduction] Or.rec ↦ 770
[reduction] PSigma.rec ↦ 514
[reduction] unfolded reducible declarations (max: 2337, num: 4):
[reduction] Nat.casesOn ↦ 2337
[reduction] Eq.ndrec ↦ 1307
[reduction] Or.casesOn ↦ 770
[reduction] PSigma.casesOn ↦ 514
[kernel] unfolded declarations (max: 1193, num: 5):
[kernel] Nat.casesOn ↦ 1193
[kernel] Nat.rec ↦ 1065
[kernel] Eq.ndrec ↦ 973
[kernel] Eq.rec ↦ 973
[kernel] Acc.rec ↦ 754
[kernel] unfolded declarations (max: 147, num: 3):
[kernel] OfNat.ofNat ↦ 147
[kernel] Add.add ↦ 61
[kernel] HAdd.hAdd ↦ 61
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
#guard_msgs in
unseal ack in
set_option diagnostics.threshold 500 in
set_option diagnostics.threshold 50 in
set_option diagnostics true in
theorem ex : ack 3 2 = 29 :=
rfl
by simp [ack]

View file

@ -1,3 +1,4 @@
@[semireducible]
def fib (n : Nat) :=
match n with
| 0 | 1 => 1

View file

@ -0,0 +1,12 @@
def g (n : Nat) : Nat :=
if h : n = 0 then
1
else
4 + g (n - 1)
termination_by n
decreasing_by simp_wf; omega
example : g 10000 = id g (id 10000) := rfl
example : id g 10000 = id g (id 10000) := rfl
example : g 10000 + 0 = g (id 10000) + 0 := rfl
example : g 10000 = g (id 10000) := rfl

View file

@ -1,5 +1,6 @@
import Lean
@[semireducible]
def ack : Nat → Nat → Nat
| 0, y => y+1
| x+1, 0 => ack x 1
@ -8,9 +9,7 @@ def ack : Nat → Nat → Nat
set_option maxHeartbeats 500
open Lean Meta
/--
error: (kernel) deterministic timeout
-/
/-- error: (kernel) deterministic timeout -/
#guard_msgs in
run_meta do
let type ← mkEq (← mkAppM ``ack #[mkNatLit 4, mkNatLit 4]) (mkNatLit 100000)

View file

@ -1,33 +1,26 @@
open List MergeSort Internal
-- If we omit the comparator, it is filled by the autoparam `fun a b => a ≤ b`
unseal mergeSort merge in
example : mergeSort [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] :=
rfl
by native_decide
unseal mergeSort merge in
example : mergeSort [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] (· ≤ ·) = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] :=
rfl
by native_decide
unseal mergeSort merge in
example : mergeSort [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] (fun x y => x/10 ≤ y/10) = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] :=
rfl
by native_decide
unseal mergeSortTR.run mergeTR.go in
example : mergeSortTR [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] :=
rfl
by native_decide
unseal mergeSortTR.run mergeTR.go in
example : mergeSortTR [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] (fun x y => x/10 ≤ y/10) = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] :=
rfl
by native_decide
unseal mergeSortTR₂.run mergeTR.go in
example : mergeSortTR₂ [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] :=
rfl
by native_decide
unseal mergeSortTR₂.run mergeTR.go in
example : mergeSortTR₂ [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] (fun x y => x/10 ≤ y/10) = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] :=
rfl
by native_decide
/-!
# Behaviour of mergeSort when the comparator is not provided, but typeclasses are missing.

View file

@ -9,8 +9,7 @@ where
#guard numChars "aαc" == 3
example : numChars "aαc" = 3 := by
rfl'
example : numChars "aαc" = 3 := by native_decide
def numChars2 (s : String) : Nat :=
go s.iter
@ -20,5 +19,4 @@ where
| true => go i.next + 1
| false => 0
example : numChars2 "aαc" = 3 := by
rfl'
example : numChars2 "aαc" = 3 := by native_decide

View file

@ -58,11 +58,29 @@ section Unsealed
unseal foo
example : foo 0 = 0 := rfl
example : foo 0 = 0 := by rfl
-- unsealing works, but does not have the desired effect
/--
error: type mismatch
rfl
has type
?_ = ?_ : Prop
but is expected to have type
foo 0 = 0 : Prop
-/
#guard_msgs in
example : foo 0 = 0 := rfl
/--
error: type mismatch
rfl
has type
?_ = ?_ : Prop
but is expected to have type
foo (n + 1) = foo n : Prop
-/
#guard_msgs in
example : foo (n+1) = foo n := rfl
example : foo (n+1) = foo n := by rfl
end Unsealed
@ -86,6 +104,7 @@ def bar : Nat → Nat
termination_by n => n
-- Once unsealed, the full internals are visible. This allows one to prove, for example
-- an equality like the following
/--
error: type mismatch
@ -98,7 +117,6 @@ but is expected to have type
#guard_msgs in
example : foo = bar := rfl
unseal foo bar in
example : foo = bar := rfl