From deb3299263b8d43324d60b82df262f7ca4a97b7c Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Sat, 1 Feb 2025 20:04:05 +0100 Subject: [PATCH] refactor: simpMatch to not etaStruct (#6901) This PR changes the `simpMatch` function, used inside the equation generator for WF-rec functions, to not do eta-expansion. This makes the process a bit more robust and disciplined, and avoids removing match-statements (and introduce projections and dependencies) that we'd rather split instead. Also adds more tracing to the equational theorem generator. Extracted from #6898. --- src/Lean/Elab/PreDefinition/WF/Eqns.lean | 13 ++++- src/Lean/Meta/Tactic/Simp/SimpTheorems.lean | 2 +- src/Lean/Meta/Tactic/Split.lean | 2 +- tests/lean/run/prefixTableStep.lean | 41 +++++++++++++++ tests/lean/run/simpMatchEta.lean | 58 +++++++++++++++++++++ 5 files changed, 112 insertions(+), 4 deletions(-) create mode 100644 tests/lean/run/prefixTableStep.lean create mode 100644 tests/lean/run/simpMatchEta.lean diff --git a/src/Lean/Elab/PreDefinition/WF/Eqns.lean b/src/Lean/Elab/PreDefinition/WF/Eqns.lean index ba3a0a2f52..332dac8a16 100644 --- a/src/Lean/Elab/PreDefinition/WF/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/WF/Eqns.lean @@ -56,24 +56,33 @@ private partial def mkProof (declName declNameNonRec : Name) (type : Expr) : Met let rec go (mvarId : MVarId) : MetaM Unit := do trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}" if ← withAtLeastTransparency .all (tryURefl mvarId) then + trace[Elab.definition.wf.eqns] "refl!" return () else if (← tryContradiction mvarId) then + trace[Elab.definition.wf.eqns] "contradiction!" return () else if let some mvarId ← simpMatch? mvarId then + trace[Elab.definition.wf.eqns] "simpMatch!" go mvarId else if let some mvarId ← simpIf? mvarId then + trace[Elab.definition.wf.eqns] "simpIf!" go mvarId else if let some mvarId ← whnfReducibleLHS? mvarId then + trace[Elab.definition.wf.eqns] "whnfReducibleLHS!" go mvarId else - let ctx ← Simp.mkContext (config := { dsimp := false }) + let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none }) match (← simpTargetStar mvarId ctx (simprocs := {})).1 with | TacticResultCNM.closed => return () - | TacticResultCNM.modified mvarId => go mvarId + | TacticResultCNM.modified mvarId => + trace[Elab.definition.wf.eqns] "simp only!" + go mvarId | TacticResultCNM.noChange => if let some mvarIds ← casesOnStuckLHS? mvarId then + trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals" mvarIds.forM go else if let some mvarIds ← splitTarget? mvarId then + trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals" mvarIds.forM go else -- At some point in the past, we looked for occurrences of Wf.fix to fold on the diff --git a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean index aafe14626f..9d5e9fb39c 100644 --- a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean +++ b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean @@ -170,7 +170,7 @@ instance : ToFormat SimpTheorem where def ppOrigin [Monad m] [MonadEnv m] [MonadError m] : Origin → m MessageData | .decl n post inv => do - let r := MessageData.ofConst (← mkConstWithLevelParams n) + let r := MessageData.ofConstName n match post, inv with | true, true => return m!"← {r}" | true, false => return r diff --git a/src/Lean/Meta/Tactic/Split.lean b/src/Lean/Meta/Tactic/Split.lean index c9057df013..a9525f1e4f 100644 --- a/src/Lean/Meta/Tactic/Split.lean +++ b/src/Lean/Meta/Tactic/Split.lean @@ -17,7 +17,7 @@ def getSimpMatchContext : MetaM Simp.Context := do Simp.mkContext (simpTheorems := {}) (congrTheorems := (← getSimpCongrTheorems)) - (config := { Simp.neutralConfig with dsimp := false }) + (config := { Simp.neutralConfig with dsimp := false, etaStruct := .none }) def simpMatch (e : Expr) : MetaM Simp.Result := do let discharge? ← SplitIf.mkDischarge? diff --git a/tests/lean/run/prefixTableStep.lean b/tests/lean/run/prefixTableStep.lean new file mode 100644 index 0000000000..27daa19025 --- /dev/null +++ b/tests/lean/run/prefixTableStep.lean @@ -0,0 +1,41 @@ +/-! Equational theorem generation regression test.-/ + +structure PrefixTable (α : Type _) extends Array (α × Nat) where + /-- Validity condition to help with termination proofs -/ + valid : (h : i < toArray.size) → toArray[i].2 ≤ i + +def PrefixTable.step [BEq α] (t : PrefixTable α) (x : α) (kf : Fin (t.size+1)) : Fin (t.size+1) := + match kf with + | ⟨k, hk⟩ => + let cont := fun () => + match k with + | 0 => ⟨0, Nat.zero_lt_succ _⟩ + | k + 1 => + have h2 : k < t.size := Nat.lt_of_succ_lt_succ hk + let k' := t.toArray[k].2 + have hk' : k' < k + 1 := Nat.lt_succ_of_le (t.valid h2) + step t x ⟨k', Nat.lt_trans hk' hk⟩ + if hsz : k < t.size then + if x == t.toArray[k].1 then + ⟨k+1, Nat.succ_lt_succ hsz⟩ + else cont () + else cont () +termination_by kf.val + +/-- +info: PrefixTable.step.eq_def.{u_1} {α : Type u_1} [BEq α] (t : PrefixTable α) (x : α) (kf : Fin (t.size + 1)) : + t.step x kf = + match kf with + | ⟨k, hk⟩ => + let cont := fun x_1 => + match k, hk with + | 0, hk => ⟨0, ⋯⟩ + | k.succ, hk => + let_fun h2 := ⋯; + let k' := t.toArray[k].snd; + let_fun hk' := ⋯; + t.step x ⟨k', ⋯⟩; + if hsz : k < t.size then if (x == t.toArray[k].fst) = true then ⟨k + 1, ⋯⟩ else cont () else cont () +-/ +#guard_msgs in +#check PrefixTable.step.eq_def diff --git a/tests/lean/run/simpMatchEta.lean b/tests/lean/run/simpMatchEta.lean new file mode 100644 index 0000000000..465383784e --- /dev/null +++ b/tests/lean/run/simpMatchEta.lean @@ -0,0 +1,58 @@ +import Lean + +def foo (n : Nat) (f : Fin n) := match f with | ⟨k, _hk⟩ => if k == 0 then true else false + +def thm : foo n f = (if f.val == 0 then true else false) := by simp [foo] + +-- NB: equational theorem only applies if motive is manifest constructor +/-- +info: foo.match_1.eq_1.{u_1} (n : Nat) (motive : Fin n → Sort u_1) (k : Nat) (_hk : k < n) + (h_1 : (k : Nat) → (_hk : k < n) → motive ⟨k, _hk⟩) : + (match ⟨k, _hk⟩ with + | ⟨k, _hk⟩ => h_1 k _hk) = + h_1 k _hk +-/ +#guard_msgs in +#check foo.match_1.eq_1 + +open Lean Meta Elab Term + +elab "simpMatch" t:term : command => do + Command.runTermElabM fun _ => do + withDeclName `_simpMatch do + let e ← elabTerm t none + synthesizeSyntheticMVarsNoPostponing (ignoreStuckTC := false) + let e' ← instantiateMVars e + let r ← Split.simpMatch e' + logInfo m!"{indentExpr e}\n==>{indentExpr r.expr}" + + +-- This should simplify + +/-- +info: + fun n f => + match ⟨↑f, ⋯⟩ with + | ⟨k, _hk⟩ => if (k == 0) = true then true else false +==> + fun n f => if (↑f == 0) = true then true else false +-/ +#guard_msgs in +simpMatch + fun (n : Nat) (f : Fin n) => (match Fin.mk f.val f.isLt with | ⟨k, _hk⟩ => if k == 0 then true else false) + +-- But this should not + +/-- +info: + fun n f => + match f with + | ⟨k, _hk⟩ => if (k == 0) = true then true else false +==> + fun n f => + match f with + | ⟨k, _hk⟩ => if (k == 0) = true then true else false +-/ +#guard_msgs in +simpMatch + fun (n : Nat) (f : Fin n) => (match f with | ⟨k, _hk⟩ => if k == 0 then true else false)