refactor: change how equations for structural recursion are proved (#10415)

This PR changes the order of steps tried when proving equational
theorems for structural recursion. In order to avoid goals that `split`
cannot handle, avoid unfolding the LHS of the equation to `.brecOn` and
`.rec` until after the RHS has been split into its final cases.

Fixes: #10195
This commit is contained in:
Joachim Breitner 2025-09-17 15:46:45 +02:00 committed by GitHub
parent e74b81169d
commit e532ce95ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 182 additions and 125 deletions

View file

@ -35,26 +35,27 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← main.mvarId!.intros
unless (← tryURefl mvarId) do -- catch easy cases
go (← deltaLHS mvarId)
go1 mvarId
instantiateMVars main
where
go (mvarId : MVarId) : MetaM Unit := do
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} step:\n{MessageData.ofGoal mvarId}") do
/--
Step 1: Split the function body into its cases, but keeping the LHS intact, because the
`.below`-added `match` statements and the `.rec` can quickly confuse `split`.
-/
go1 (mvarId : MVarId) : MetaM Unit := do
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} go1:\n{MessageData.ofGoal mvarId}") do
if (← tryURefl mvarId) then
trace[Elab.definition.structural.eqns] "tryURefl succeeded"
return ()
else if (← tryContradiction mvarId) then
trace[Elab.definition.structural.eqns] "tryContadiction succeeded"
return ()
else if let some mvarId ← whnfReducibleLHS? mvarId then
trace[Elab.definition.structural.eqns] "whnfReducibleLHS succeeded"
go mvarId
else if let some mvarId ← simpMatch? mvarId then
trace[Elab.definition.structural.eqns] "simpMatch? succeeded"
go mvarId
go1 mvarId
else if let some mvarId ← simpIf? mvarId (useNewSemantics := true) then
trace[Elab.definition.structural.eqns] "simpIf? succeeded"
go mvarId
go1 mvarId
else
let ctx ← Simp.mkContext
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
@ -62,17 +63,50 @@ where
trace[Elab.definition.structural.eqns] "simpTargetStar closed the goal"
| TacticResultCNM.modified mvarId =>
trace[Elab.definition.structural.eqns] "simpTargetStar modified the goal"
go mvarId
go1 mvarId
| TacticResultCNM.noChange =>
if let some mvarId ← deltaRHS? mvarId declName then
trace[Elab.definition.structural.eqns] "deltaRHS? succeeded"
go mvarId
else if let some mvarIds ← casesOnStuckLHS? mvarId then
if let some mvarIds ← casesOnStuckLHS? mvarId then
trace[Elab.definition.structural.eqns] "casesOnStuckLHS? succeeded"
mvarIds.forM go
mvarIds.forM go1
else if let some mvarIds ← splitTarget? mvarId (useNewSemantics := true) then
trace[Elab.definition.structural.eqns] "splitTarget? succeeded"
mvarIds.forM go
mvarIds.forM go1
else
go2 (← deltaLHS mvarId)
/-- Step 2: Unfold the lhs to expose the recursor. -/
go2 (mvarId : MVarId) : MetaM Unit := do
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} go2:\n{MessageData.ofGoal mvarId}") do
if let some mvarId ← whnfReducibleLHS? mvarId then
go2 mvarId
else
go3 mvarId
/-- Step 3: Simplify the match and if statements on the left hand side, until we have rfl. -/
go3 (mvarId : MVarId) : MetaM Unit := do
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} go3:\n{MessageData.ofGoal mvarId}") do
if (← tryURefl mvarId) then
trace[Elab.definition.structural.eqns] "tryURefl succeeded"
return ()
else if (← tryContradiction mvarId) then
trace[Elab.definition.structural.eqns] "tryContadiction succeeded"
return ()
else if let some mvarId ← simpMatch? mvarId then
trace[Elab.definition.structural.eqns] "simpMatch? succeeded"
go3 mvarId
else if let some mvarId ← simpIf? mvarId (useNewSemantics := true) then
trace[Elab.definition.structural.eqns] "simpIf? succeeded"
go3 mvarId
else
let ctx ← Simp.mkContext
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed =>
trace[Elab.definition.structural.eqns] "simpTargetStar closed the goal"
| TacticResultCNM.modified mvarId =>
trace[Elab.definition.structural.eqns] "simpTargetStar modified the goal"
go3 mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
trace[Elab.definition.structural.eqns] "casesOnStuckLHS? succeeded"
mvarIds.forM go3
else
throwError "failed to generate equational theorem for `{.ofConstName declName}`\n{MessageData.ofGoal mvarId}"

View file

@ -1,5 +1,7 @@
#include "util/options.h"
// please update stage0
namespace lean {
options get_default_options() {
options opts;

View file

@ -12,45 +12,39 @@ def optimize : Expr → Expr
/--
error: Failed to realize constant optimize.eq_def:
failed to generate equational theorem for `optimize`
case h_2
e1 : Expr
bop : Unit
i : BitVec 32
heq : optimize e1 = Expr.const i
bop✝ bop_1 : Unit
x : Expr
x_3 :
∀ (i : BitVec 32),
(Expr.rec (fun i => ⟨Expr.const i, PUnit.unit⟩)
(fun op e1 e1_ih =>
⟨match op, e1_ih.1 with
| x, Expr.const i => Expr.op op (Expr.const 0)
| x, x_1 => Expr.const 0,
e1_ih⟩)
e1).1 =
Expr.const i →
False
⊢ Expr.const 0 = Expr.op bop✝ (Expr.const 0)
⊢ (match bop,
(Expr.rec (fun i => ⟨Expr.const i, PUnit.unit⟩)
(fun op e1 e1_ih =>
⟨match op, e1_ih.1 with
| x, Expr.const i => Expr.op op (Expr.const 0)
| x, x_1 => Expr.const 0,
e1_ih⟩)
e1).1 with
| x, Expr.const i => Expr.op bop (Expr.const 0)
| x, x_1 => Expr.const 0) =
Expr.op bop (Expr.const 0)
---
error: Failed to realize constant optimize.eq_def:
failed to generate equational theorem for `optimize`
case h_2
e1 : Expr
bop : Unit
i : BitVec 32
heq : optimize e1 = Expr.const i
bop✝ bop_1 : Unit
x : Expr
x_3 :
∀ (i : BitVec 32),
(Expr.rec (fun i => ⟨Expr.const i, PUnit.unit⟩)
(fun op e1 e1_ih =>
⟨match op, e1_ih.1 with
| x, Expr.const i => Expr.op op (Expr.const 0)
| x, x_1 => Expr.const 0,
e1_ih⟩)
e1).1 =
Expr.const i →
False
⊢ Expr.const 0 = Expr.op bop✝ (Expr.const 0)
⊢ (match bop,
(Expr.rec (fun i => ⟨Expr.const i, PUnit.unit⟩)
(fun op e1 e1_ih =>
⟨match op, e1_ih.1 with
| x, Expr.const i => Expr.op op (Expr.const 0)
| x, x_1 => Expr.const 0,
e1_ih⟩)
e1).1 with
| x, Expr.const i => Expr.op bop (Expr.const 0)
| x, x_1 => Expr.const 0) =
Expr.op bop (Expr.const 0)
---
error: Unknown identifier `optimize.eq_def`
-/
@ -59,24 +53,21 @@ error: Unknown identifier `optimize.eq_def`
/--
error: failed to generate equational theorem for `optimize`
case h_2
e1 : Expr
bop : Unit
i : BitVec 32
heq : optimize e1 = Expr.const i
bop✝ bop_1 : Unit
x : Expr
x_3 :
∀ (i : BitVec 32),
(Expr.rec (fun i => ⟨Expr.const i, PUnit.unit⟩)
(fun op e1 e1_ih =>
⟨match op, e1_ih.1 with
| x, Expr.const i => Expr.op op (Expr.const 0)
| x, x_1 => Expr.const 0,
e1_ih⟩)
e1).1 =
Expr.const i →
False
⊢ Expr.const 0 = Expr.op bop✝ (Expr.const 0)
⊢ (match bop,
(Expr.rec (fun i => ⟨Expr.const i, PUnit.unit⟩)
(fun op e1 e1_ih =>
⟨match op, e1_ih.1 with
| x, Expr.const i => Expr.op op (Expr.const 0)
| x, x_1 => Expr.const 0,
e1_ih⟩)
e1).1 with
| x, Expr.const i => Expr.op bop (Expr.const 0)
| x, x_1 => Expr.const 0) =
Expr.op bop (Expr.const 0)
-/
#guard_msgs in
#print equations optimize

View file

@ -28,51 +28,30 @@ termination_by structural x
#print sig decEqVecPlain.match_1.eq_1
/--
error: Failed to realize constant decEqVecPlain.eq_def:
failed to generate equational theorem for `decEqVecPlain`
case nil.isTrue
α : Type u_1
inst : DecidableEq α
x_1 : Vec α 0
h✝ : Vec.nil.ctorIdx = x_1.ctorIdx
⊢ (match (motive :=
(a : Nat) →
(x x_1 : Vec α a) →
x.ctorIdx = x_1.ctorIdx →
Vec.rec PUnit (fun a {n} a_1 a_ih => ((x_1 : Vec α n) → Decidable (a_1 = x_1)) ×' a_ih) x →
Decidable (x = x_1))
0, Vec.nil, x_1, ⋯ with
| .(0), Vec.nil, Vec.nil, x => fun x => isTrue ⋯
| .(n + 1), Vec.cons a_1 a_2, Vec.cons b b_1, x => fun x_2 =>
if h : a_1 = b then
Eq.rec (motive := fun x x_3 =>
(Vec.cons a_1 a_2).ctorIdx = (Vec.cons x b_1).ctorIdx → Decidable (Vec.cons a_1 a_2 = Vec.cons x b_1))
(fun x =>
if h_2 : a_2 = b_1 then
Eq.rec (motive := fun x x_3 =>
(Vec.cons a_1 a_2).ctorIdx = (Vec.cons a_1 x).ctorIdx →
Decidable (Vec.cons a_1 a_2 = Vec.cons a_1 x))
(fun x => isTrue ⋯) h_2 x
else isFalse ⋯)
⋯ x
else isFalse ⋯)
PUnit.unit =
match 0, Vec.nil, x_1, ⋯ with
info: theorem decEqVecPlain.eq_def.{u_1} : ∀ {α : Type u_1} {a : Nat} [inst : DecidableEq α] (x x_1 : Vec α a),
decEqVecPlain x x_1 =
if h : x.ctorIdx = x_1.ctorIdx then
match a, x, x_1, h with
| .(0), Vec.nil, Vec.nil, x => isTrue ⋯
| .(n + 1), Vec.cons a_1 a_2, Vec.cons b b_1, x =>
if h : a_1 = b then
Eq.rec (motive := fun x x_2 =>
(Vec.cons a_1 a_2).ctorIdx = (Vec.cons x b_1).ctorIdx → Decidable (Vec.cons a_1 a_2 = Vec.cons x b_1))
if h_1 : a_1 = b then
Eq.ndrec (motive := fun b =>
(Vec.cons a_1 a_2).ctorIdx = (Vec.cons b b_1).ctorIdx → Decidable (Vec.cons a_1 a_2 = Vec.cons b b_1))
(fun x =>
have inst_1 := decEqVecPlain a_2 b_1;
if h_2 : a_2 = b_1 then
Eq.rec (motive := fun x x_2 =>
(Vec.cons a_1 a_2).ctorIdx = (Vec.cons a_1 x).ctorIdx → Decidable (Vec.cons a_1 a_2 = Vec.cons a_1 x))
(fun x => isTrue ⋯) h_2 x
Eq.ndrec (motive := fun b_1 =>
(Vec.cons a_1 a_2).ctorIdx = (Vec.cons a_1 b_1).ctorIdx →
have inst := decEqVecPlain a_2 b_1;
Decidable (Vec.cons a_1 a_2 = Vec.cons a_1 b_1))
(fun x =>
have inst := decEqVecPlain a_2 a_2;
isTrue ⋯)
h_2 x
else isFalse ⋯)
⋯ x
h_1 x
else isFalse ⋯
---
error: Unknown constant `decEqVecPlain.eq_def`
else isFalse ⋯
-/
#guard_msgs(pass trace, all) in
#print sig decEqVecPlain.eq_def
@ -94,32 +73,31 @@ termination_by structural x
/--
error: Failed to realize constant foo.eq_def:
failed to generate equational theorem for `foo`
case isTrue
n_1 : Nat
a : I n_1
x' : I (n_1 + 1)
h✝ : P a.cons
⊢ (match (motive := (n : Nat) → (x : I n) → I n → P x → I.rec (fun {n} a a_ih => (I n → R a) ×' a_ih) x → R x)
n_1 + 1, a.cons, x', ⋯ with
| .(n + 1), a_2.cons, a_2'.cons, x => fun x => testSorry (x.1 a_2') h✝)
(I.rec
(fun {n} a a_ih =>
⟨fun x' =>
if h : P a.cons then
(match (motive :=
(n : Nat) → (x : I n) → I n → P x → I.rec (fun {n} a a_ih => (I n → R a) ×' a_ih) x → R x) n + 1,
a.cons, x', ⋯ with
| .(n_2 + 1), a_2.cons, a_2'.cons, x => fun x => testSorry (x.1 a_2') h)
a_ih
else testSorry,
a_ih⟩)
a) =
match n_1 + 1, a.cons, x', ⋯ with
| .(n + 1), a_2.cons, a_2'.cons, x => testSorry (foo a_2 a_2') h✝
---
error: Unknown constant `foo.eq_def`
info: theorem foo.eq_def.{u_1, u_2} : ∀ {n : Nat} (x : I n) (x' : I n),
foo x x' =
if h : P x then
match n, x, x', ⋯ with
| .(n_1 + 1), a_2.cons, a_2'.cons, x_1 => testSorry (foo a_2 a_2') h
else testSorry
-/
#guard_msgs(pass trace, all) in
#print sig foo.eq_def
noncomputable def nondep (x x' : I n) : Bool :=
if h : P x then
match (generalizing := false) x, x', id h with --NB: non-FVar discr
| .cons a_2, .cons a_2', _ => nondep a_2 a_2'
else false
termination_by structural x
/--
info: theorem nondep.eq_def.{u_1, u_2} : ∀ {n : Nat} (x : I n) (x' : I n),
nondep x x' =
if h : P x then
match n, x, x', ⋯ with
| .(n + 1), a_2.cons, a_2'.cons, x => nondep a_2 a_2'
else false
-/
#guard_msgs in
#print sig nondep.eq_def

View file

@ -55,6 +55,58 @@ example (n0 n : Nat) (h : id n0 = n) :
· sorry
· sorry
-- Variant where the discriminant is already a constructor (so substituting the generalized equation
-- may actually help)
/--
error: Tactic `split` failed: Could not split an `if` or `match` expression in the goal
Hint: Use `set_option trace.split.failure true` to display additional diagnostic information
n : Nat
⊢ Fin.last n.succ =
match n.succ with
| 0 => Fin.last 0
| n.succ => Fin.last (n + 1)
-/
#guard_msgs in
example (n : Nat) : Fin.last n.succ = match (motive := ∀ n, Fin (n+1)) Nat.succ n with
| 0 => Fin.last 0
| n + 1 => Fin.last (n + 1) := by
split <;> rfl
-- Manual generalization; the type-incorrect variant done by split
/--
error: Type mismatch
match m with
| 0 => Fin.last 0
| n.succ => Fin.last (n + 1)
has type
Fin (m + 1)
but is expected to have type
Fin (n.succ + 1)
---
error: (kernel) declaration has metavariables '_example'
-/
#guard_msgs in
example (n m : Nat) (h : n.succ = m) : Fin.last n.succ = match (motive := ∀ n, Fin (n+1)) m with
| 0 => Fin.last 0
| n + 1 => Fin.last (n + 1) := sorry
-- What about using ndrec here?
example (n m : Nat) (h : n.succ = m) : Fin.last n.succ =
h.symm.ndrec (motive := fun n => Fin (n + 1))
(match (motive := ∀ n, Fin (n+1)) m with
| 0 => Fin.last 0
| n + 1 => Fin.last (n + 1)) := by
split
· contradiction
· -- the cast is still here!
cases h
-- now the cast can rfl away
rfl
-- Variant with proof-valued discriminant. This works (and always has):