refactor: WF.Eqns: rewrite fix without duplicating F (#6859)
This PR changes how WF.Eqns unfolds the fixpoint. Instead of delta'ing until we have `fix`, and then blindly applying `fix_eq`, we delta one step less and preserve the function on the right hand side. This leads to smaller terms in the next step, so easier to debug, possibly faster, possibly more robust.
This commit is contained in:
parent
dc445d7af6
commit
cd62b8cd80
2 changed files with 30 additions and 24 deletions
|
|
@ -10,7 +10,6 @@ import Lean.Elab.PreDefinition.Basic
|
|||
import Lean.Elab.PreDefinition.Eqns
|
||||
import Lean.Meta.ArgsPacker.Basic
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Internal.Order.Basic
|
||||
|
||||
namespace Lean.Elab.WF
|
||||
open Meta
|
||||
|
|
@ -23,35 +22,33 @@ structure EqnInfo extends EqnInfoCore where
|
|||
argsPacker : ArgsPacker
|
||||
deriving Inhabited
|
||||
|
||||
private partial def deltaLHSUntilFix (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
||||
let target ← mvarId.getType'
|
||||
let some (_, lhs, _) := target.eq? | throwTacticEx `deltaLHSUntilFix mvarId "equality expected"
|
||||
if lhs.isAppOf ``WellFounded.fix then
|
||||
return mvarId
|
||||
else if lhs.isAppOf ``Order.fix then
|
||||
return mvarId
|
||||
else
|
||||
deltaLHSUntilFix (← deltaLHS mvarId)
|
||||
|
||||
private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
||||
let target ← mvarId.getType'
|
||||
let some (_, lhs, rhs) := target.eq? | unreachable!
|
||||
let h ←
|
||||
if lhs.isAppOf ``WellFounded.fix then
|
||||
pure <| mkAppN (mkConst ``WellFounded.fix_eq lhs.getAppFn.constLevels!) lhs.getAppArgs
|
||||
else if lhs.isAppOf ``Order.fix then
|
||||
let x := lhs.getAppArgs.back!
|
||||
let args := lhs.getAppArgs.pop
|
||||
mkAppM ``congrFun #[mkAppN (mkConst ``Order.fix_eq lhs.getAppFn.constLevels!) args, x]
|
||||
else
|
||||
throwTacticEx `rwFixEq mvarId "expected fixed-point application"
|
||||
let some (_, _, lhsNew) := (← inferType h).eq? | unreachable!
|
||||
|
||||
-- lhs should be an application of the declNameNonrec, which unfolds to an
|
||||
-- application of fix in one step
|
||||
let some lhs' ← delta? lhs | throwError "rwFixEq: cannot delta-reduce {lhs}"
|
||||
let_expr WellFounded.fix _α _C _r _hwf F x := lhs'
|
||||
| throwTacticEx `rwFixEq mvarId "expected saturated fixed-point application in {lhs'}"
|
||||
let h := mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs
|
||||
|
||||
-- We used to just rewrite with `fix_eq` and continue with whatever RHS that produces, but that
|
||||
-- would include more copies of `fix` resulting in large and confusing terms.
|
||||
-- Instead we manually construct the new term in terms of the current functions,
|
||||
-- which should be headed by the `declNameNonRec`, and should be defeq to the expected type
|
||||
|
||||
-- if lhs == e x and lhs' == fix .., then lhsNew := e x = F x (fun y _ => e y)
|
||||
let ftype := (← inferType (mkApp F x)).bindingDomain!
|
||||
let f' ← forallBoundedTelescope ftype (some 2) fun ys _ => do
|
||||
mkLambdaFVars ys (.app lhs.appFn! ys[0]!)
|
||||
let lhsNew := mkApp2 F x f'
|
||||
let targetNew ← mkEq lhsNew rhs
|
||||
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
|
||||
mvarId.assign (← mkEqTrans h mvarNew)
|
||||
return mvarNew.mvarId!
|
||||
|
||||
private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
|
||||
private partial def mkProof (declName declNameNonRec : Name) (type : Expr) : MetaM Expr := do
|
||||
trace[Elab.definition.wf.eqns] "proving: {type}"
|
||||
withNewMCtxDepth do
|
||||
let main ← mkFreshExprSyntheticOpaqueMVar type
|
||||
|
|
@ -83,7 +80,10 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
|
|||
-- LHS (introduced in 096e4eb), but it seems that code path was never used,
|
||||
-- so #3133 removed it again (and can be recovered from there if this was premature).
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
go (← rwFixEq (← deltaLHSUntilFix mvarId))
|
||||
|
||||
let mvarId ← if declName != declNameNonRec then deltaLHS mvarId else pure mvarId
|
||||
let mvarId ← rwFixEq mvarId
|
||||
go mvarId
|
||||
instantiateMVars main
|
||||
|
||||
def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
|
||||
|
|
@ -101,7 +101,7 @@ def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
|
|||
trace[Elab.definition.wf.eqns] "{eqnTypes[i]}"
|
||||
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
|
||||
thmNames := thmNames.push name
|
||||
let value ← mkProof declName type
|
||||
let value ← mkProof declName info.declNameNonRec type
|
||||
let (type, value) ← removeUnusedEqnHypotheses type value
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
|
|
|
|||
|
|
@ -42,6 +42,12 @@ info: [simp] Diagnostics
|
|||
[simp] ack.eq_1 ↦ 768, succeeded: 768
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
info: [diag] Diagnostics
|
||||
[kernel] unfolded declarations (max: 29, num: 2):
|
||||
[kernel] Nat.casesOn ↦ 29
|
||||
[kernel] Nat.rec ↦ 29
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
error: tactic 'simp' failed, nested error:
|
||||
maximum recursion depth has been reached
|
||||
use `set_option maxRecDepth <num>` to increase limit
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue