From cd62b8cd80baf3a83356333ed45f52dfced70225 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Thu, 30 Jan 2025 11:23:18 +0100 Subject: [PATCH] refactor: WF.Eqns: rewrite fix without duplicating F (#6859) This PR changes how WF.Eqns unfolds the fixpoint. Instead of delta'ing until we have `fix`, and then blindly applying `fix_eq`, we delta one step less and preserve the function on the right hand side. This leads to smaller terms in the next step, so easier to debug, possibly faster, possibly more robust. --- src/Lean/Elab/PreDefinition/WF/Eqns.lean | 48 ++++++++++++------------ tests/lean/run/simpDiag.lean | 6 +++ 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/Lean/Elab/PreDefinition/WF/Eqns.lean b/src/Lean/Elab/PreDefinition/WF/Eqns.lean index 43c17b93b5..ba3a0a2f52 100644 --- a/src/Lean/Elab/PreDefinition/WF/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/WF/Eqns.lean @@ -10,7 +10,6 @@ import Lean.Elab.PreDefinition.Basic import Lean.Elab.PreDefinition.Eqns import Lean.Meta.ArgsPacker.Basic import Init.Data.Array.Basic -import Init.Internal.Order.Basic namespace Lean.Elab.WF open Meta @@ -23,35 +22,33 @@ structure EqnInfo extends EqnInfoCore where argsPacker : ArgsPacker deriving Inhabited -private partial def deltaLHSUntilFix (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do - let target ← mvarId.getType' - let some (_, lhs, _) := target.eq? | throwTacticEx `deltaLHSUntilFix mvarId "equality expected" - if lhs.isAppOf ``WellFounded.fix then - return mvarId - else if lhs.isAppOf ``Order.fix then - return mvarId - else - deltaLHSUntilFix (← deltaLHS mvarId) - private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do let target ← mvarId.getType' let some (_, lhs, rhs) := target.eq? | unreachable! - let h ← - if lhs.isAppOf ``WellFounded.fix then - pure <| mkAppN (mkConst ``WellFounded.fix_eq lhs.getAppFn.constLevels!) lhs.getAppArgs - else if lhs.isAppOf ``Order.fix then - let x := lhs.getAppArgs.back! - let args := lhs.getAppArgs.pop - mkAppM ``congrFun #[mkAppN (mkConst ``Order.fix_eq lhs.getAppFn.constLevels!) args, x] - else - throwTacticEx `rwFixEq mvarId "expected fixed-point application" - let some (_, _, lhsNew) := (← inferType h).eq? | unreachable! + + -- lhs should be an application of the declNameNonrec, which unfolds to an + -- application of fix in one step + let some lhs' ← delta? lhs | throwError "rwFixEq: cannot delta-reduce {lhs}" + let_expr WellFounded.fix _α _C _r _hwf F x := lhs' + | throwTacticEx `rwFixEq mvarId "expected saturated fixed-point application in {lhs'}" + let h := mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs + + -- We used to just rewrite with `fix_eq` and continue with whatever RHS that produces, but that + -- would include more copies of `fix` resulting in large and confusing terms. + -- Instead we manually construct the new term in terms of the current functions, + -- which should be headed by the `declNameNonRec`, and should be defeq to the expected type + + -- if lhs == e x and lhs' == fix .., then lhsNew := e x = F x (fun y _ => e y) + let ftype := (← inferType (mkApp F x)).bindingDomain! + let f' ← forallBoundedTelescope ftype (some 2) fun ys _ => do + mkLambdaFVars ys (.app lhs.appFn! ys[0]!) + let lhsNew := mkApp2 F x f' let targetNew ← mkEq lhsNew rhs let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew mvarId.assign (← mkEqTrans h mvarNew) return mvarNew.mvarId! -private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do +private partial def mkProof (declName declNameNonRec : Name) (type : Expr) : MetaM Expr := do trace[Elab.definition.wf.eqns] "proving: {type}" withNewMCtxDepth do let main ← mkFreshExprSyntheticOpaqueMVar type @@ -83,7 +80,10 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do -- LHS (introduced in 096e4eb), but it seems that code path was never used, -- so #3133 removed it again (and can be recovered from there if this was premature). throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" - go (← rwFixEq (← deltaLHSUntilFix mvarId)) + + let mvarId ← if declName != declNameNonRec then deltaLHS mvarId else pure mvarId + let mvarId ← rwFixEq mvarId + go mvarId instantiateMVars main def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) := @@ -101,7 +101,7 @@ def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) := trace[Elab.definition.wf.eqns] "{eqnTypes[i]}" let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1) thmNames := thmNames.push name - let value ← mkProof declName type + let value ← mkProof declName info.declNameNonRec type let (type, value) ← removeUnusedEqnHypotheses type value addDecl <| Declaration.thmDecl { name, type, value diff --git a/tests/lean/run/simpDiag.lean b/tests/lean/run/simpDiag.lean index 20feb4999d..f31dc67feb 100644 --- a/tests/lean/run/simpDiag.lean +++ b/tests/lean/run/simpDiag.lean @@ -42,6 +42,12 @@ info: [simp] Diagnostics [simp] ack.eq_1 ↦ 768, succeeded: 768 use `set_option diagnostics.threshold ` to control threshold for reporting counters --- +info: [diag] Diagnostics + [kernel] unfolded declarations (max: 29, num: 2): + [kernel] Nat.casesOn ↦ 29 + [kernel] Nat.rec ↦ 29 + use `set_option diagnostics.threshold ` to control threshold for reporting counters +--- error: tactic 'simp' failed, nested error: maximum recursion depth has been reached use `set_option maxRecDepth ` to increase limit