refactor: simpMatch to not etaStruct (#6901)

This PR changes the `simpMatch` function, used inside the equation
generator for WF-rec functions, to not do eta-expansion.

This makes the process a bit more robust and disciplined, and avoids
removing match-statements (and introduce projections and dependencies)
that we'd rather split instead.

Also adds more tracing to the equational theorem generator.

Extracted from #6898.
This commit is contained in:
Joachim Breitner 2025-02-01 20:04:05 +01:00 committed by GitHub
parent 2b0e75748b
commit deb3299263
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 112 additions and 4 deletions

View file

@ -56,24 +56,33 @@ private partial def mkProof (declName declNameNonRec : Name) (type : Expr) : Met
let rec go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
if ← withAtLeastTransparency .all (tryURefl mvarId) then
trace[Elab.definition.wf.eqns] "refl!"
return ()
else if (← tryContradiction mvarId) then
trace[Elab.definition.wf.eqns] "contradiction!"
return ()
else if let some mvarId ← simpMatch? mvarId then
trace[Elab.definition.wf.eqns] "simpMatch!"
go mvarId
else if let some mvarId ← simpIf? mvarId then
trace[Elab.definition.wf.eqns] "simpIf!"
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
trace[Elab.definition.wf.eqns] "whnfReducibleLHS!"
go mvarId
else
let ctx ← Simp.mkContext (config := { dsimp := false })
let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.modified mvarId =>
trace[Elab.definition.wf.eqns] "simp only!"
go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals"
mvarIds.forM go
else if let some mvarIds ← splitTarget? mvarId then
trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals"
mvarIds.forM go
else
-- At some point in the past, we looked for occurrences of Wf.fix to fold on the

View file

@ -170,7 +170,7 @@ instance : ToFormat SimpTheorem where
def ppOrigin [Monad m] [MonadEnv m] [MonadError m] : Origin → m MessageData
| .decl n post inv => do
let r := MessageData.ofConst (← mkConstWithLevelParams n)
let r := MessageData.ofConstName n
match post, inv with
| true, true => return m!"← {r}"
| true, false => return r

View file

@ -17,7 +17,7 @@ def getSimpMatchContext : MetaM Simp.Context := do
Simp.mkContext
(simpTheorems := {})
(congrTheorems := (← getSimpCongrTheorems))
(config := { Simp.neutralConfig with dsimp := false })
(config := { Simp.neutralConfig with dsimp := false, etaStruct := .none })
def simpMatch (e : Expr) : MetaM Simp.Result := do
let discharge? ← SplitIf.mkDischarge?

View file

@ -0,0 +1,41 @@
/-! Equational theorem generation regression test.-/
structure PrefixTable (α : Type _) extends Array (α × Nat) where
/-- Validity condition to help with termination proofs -/
valid : (h : i < toArray.size) → toArray[i].2 ≤ i
def PrefixTable.step [BEq α] (t : PrefixTable α) (x : α) (kf : Fin (t.size+1)) : Fin (t.size+1) :=
match kf with
| ⟨k, hk⟩ =>
let cont := fun () =>
match k with
| 0 => ⟨0, Nat.zero_lt_succ _⟩
| k + 1 =>
have h2 : k < t.size := Nat.lt_of_succ_lt_succ hk
let k' := t.toArray[k].2
have hk' : k' < k + 1 := Nat.lt_succ_of_le (t.valid h2)
step t x ⟨k', Nat.lt_trans hk' hk⟩
if hsz : k < t.size then
if x == t.toArray[k].1 then
⟨k+1, Nat.succ_lt_succ hsz⟩
else cont ()
else cont ()
termination_by kf.val
/--
info: PrefixTable.step.eq_def.{u_1} {α : Type u_1} [BEq α] (t : PrefixTable α) (x : α) (kf : Fin (t.size + 1)) :
t.step x kf =
match kf with
| ⟨k, hk⟩ =>
let cont := fun x_1 =>
match k, hk with
| 0, hk => ⟨0, ⋯⟩
| k.succ, hk =>
let_fun h2 := ⋯;
let k' := t.toArray[k].snd;
let_fun hk' := ⋯;
t.step x ⟨k', ⋯⟩;
if hsz : k < t.size then if (x == t.toArray[k].fst) = true then ⟨k + 1, ⋯⟩ else cont () else cont ()
-/
#guard_msgs in
#check PrefixTable.step.eq_def

View file

@ -0,0 +1,58 @@
import Lean
def foo (n : Nat) (f : Fin n) := match f with | ⟨k, _hk⟩ => if k == 0 then true else false
def thm : foo n f = (if f.val == 0 then true else false) := by simp [foo]
-- NB: equational theorem only applies if motive is manifest constructor
/--
info: foo.match_1.eq_1.{u_1} (n : Nat) (motive : Fin n → Sort u_1) (k : Nat) (_hk : k < n)
(h_1 : (k : Nat) → (_hk : k < n) → motive ⟨k, _hk⟩) :
(match ⟨k, _hk⟩ with
| ⟨k, _hk⟩ => h_1 k _hk) =
h_1 k _hk
-/
#guard_msgs in
#check foo.match_1.eq_1
open Lean Meta Elab Term
elab "simpMatch" t:term : command => do
Command.runTermElabM fun _ => do
withDeclName `_simpMatch do
let e ← elabTerm t none
synthesizeSyntheticMVarsNoPostponing (ignoreStuckTC := false)
let e' ← instantiateMVars e
let r ← Split.simpMatch e'
logInfo m!"{indentExpr e}\n==>{indentExpr r.expr}"
-- This should simplify
/--
info:
fun n f =>
match ⟨↑f, ⋯⟩ with
| ⟨k, _hk⟩ => if (k == 0) = true then true else false
==>
fun n f => if (↑f == 0) = true then true else false
-/
#guard_msgs in
simpMatch
fun (n : Nat) (f : Fin n) => (match Fin.mk f.val f.isLt with | ⟨k, _hk⟩ => if k == 0 then true else false)
-- But this should not
/--
info:
fun n f =>
match f with
| ⟨k, _hk⟩ => if (k == 0) = true then true else false
==>
fun n f =>
match f with
| ⟨k, _hk⟩ => if (k == 0) = true then true else false
-/
#guard_msgs in
simpMatch
fun (n : Nat) (f : Fin n) => (match f with | ⟨k, _hk⟩ => if k == 0 then true else false)