fix: equation generation for nested recursive definitions
The issue was raised on Zulip. The issue is triggered in declarations containing overlapping patterns and nested recursive definitions occurring as the discriminant of `match`-expressions. Recall that Lean 4 generates conditional equations for declarations containing overlapping patterns. To address the issue we had to "fold" `WellFounded.fix` applications back as recursive applications of the functions being defined. The new test exposes the issue.
This commit is contained in:
parent
6652d2665d
commit
096e4eb6d0
2 changed files with 191 additions and 30 deletions
|
|
@ -39,17 +39,85 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := withMVarContext mvarId d
|
|||
private def hasWellFoundedFix (e : Expr) : Bool :=
|
||||
Option.isSome <| e.find? (·.isConstOf ``WellFounded.fix)
|
||||
|
||||
def simpMatchWF? (mvarId : MVarId) (info : EqnInfo) : MetaM (Option MVarId) := withMVarContext mvarId do
|
||||
let target ← instantiateMVars (← getMVarType mvarId)
|
||||
let targetNew ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre })
|
||||
let mvarIdNew ← applySimpResultToTarget mvarId target targetNew
|
||||
if mvarId != mvarIdNew then return some mvarIdNew else return none
|
||||
/--
|
||||
Helper function for decoding the packed argument for a `WellFounded.fix` application.
|
||||
Recall that we use `PSum` and `PSigma` for packing the arguments of mutually recursive nary functions.
|
||||
-/
|
||||
private partial def decodePackedArg? (info : EqnInfo) (e : Expr) : Option (Name × Array Expr) := OptionM.run do
|
||||
if info.declNames.size == 1 then
|
||||
let args := decodePSigma e #[]
|
||||
return (info.declNames[0], args)
|
||||
else
|
||||
decodePSum? e 0
|
||||
where
|
||||
decodePSum? (e : Expr) (i : Nat) : Option (Name × Array Expr) := OptionM.run do
|
||||
if e.isAppOfArity ``PSum.inl 3 then
|
||||
decodePSum? e.appArg! i
|
||||
else if e.isAppOfArity ``PSum.inr 3 then
|
||||
decodePSum? e.appArg! (i+1)
|
||||
else
|
||||
guard (i < info.declNames.size)
|
||||
return (info.declNames[i], decodePSigma e #[])
|
||||
|
||||
decodePSigma (e : Expr) (acc : Array Expr) : Array Expr :=
|
||||
/- TODO: check arity of the given function. If it takes a PSigma as the last argument,
|
||||
this function will produce incorrect results. -/
|
||||
if e.isAppOfArity ``PSigma.mk 4 then
|
||||
decodePSigma e.appArg! (acc.push e.appFn!.appArg!)
|
||||
else
|
||||
acc.push e
|
||||
|
||||
/--
|
||||
Try to fold `WellFounded.fix` applications that represent recursive applications of the functions in `info.declNames`.
|
||||
We need that to make sure `simpMatchWF?` succeeds at goals such as
|
||||
```lean
|
||||
...
|
||||
h : g x = 0
|
||||
...
|
||||
|- (match (WellFounded.fix ...) with | ...) = ...
|
||||
```
|
||||
where `WellFounded.fix ...` can be folded back to `g x`.
|
||||
-/
|
||||
private def tryToFoldWellFoundedFix (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (e : Expr) : MetaM Expr := do
|
||||
if hasWellFoundedFix e then
|
||||
transform e (pre := pre)
|
||||
else
|
||||
return e
|
||||
where
|
||||
pre (e : Expr) : MetaM TransformStep := do
|
||||
let e' := e.headBeta
|
||||
if e'.isAppOf ``WellFounded.fix && e'.getAppNumArgs >= 6 then
|
||||
let args := e'.getAppArgs
|
||||
let packedArg := args[5]
|
||||
let extraArgs := args[6:]
|
||||
if let some (declName, args) := decodePackedArg? info packedArg then
|
||||
let candidate := mkAppN (mkAppN (mkAppN (mkConst declName us) fixedPrefix) args) extraArgs
|
||||
trace[Elab.definition.wf] "found nested WF at discr {candidate}"
|
||||
if (← withDefault <| isDefEq candidate e) then
|
||||
return TransformStep.visit candidate
|
||||
return TransformStep.visit e
|
||||
|
||||
/--
|
||||
Simplify `match`-expressions when trying to prove equation theorems for a recursive declaration defined using well-founded recursion.
|
||||
It is similar to `simpMatch?`, but is also tries to fold `WellFounded.fix` applications occurring in discriminants.
|
||||
See comment at `tryToFoldWellFoundedFix`.
|
||||
-/
|
||||
def simpMatchWF? (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (mvarId : MVarId) : MetaM (Option MVarId) :=
|
||||
withMVarContext mvarId do
|
||||
let target ← instantiateMVars (← getMVarType mvarId)
|
||||
let targetNew ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre })
|
||||
let mvarIdNew ← applySimpResultToTarget mvarId target targetNew
|
||||
if mvarId != mvarIdNew then return some mvarIdNew else return none
|
||||
where
|
||||
pre (e : Expr) : SimpM Simp.Step := do
|
||||
let some app ← matchMatcherApp? e | return Simp.Step.visit { expr := e }
|
||||
if app.discrs.any hasWellFoundedFix then
|
||||
-- TODO: try to fold `WellFounded.fix` occurrences in the discriminant
|
||||
pure ()
|
||||
let discrsNew ← app.discrs.mapM (tryToFoldWellFoundedFix info us fixedPrefix ·)
|
||||
if discrsNew != app.discrs then
|
||||
let app := { app with discrs := discrsNew }
|
||||
let eNew := app.toExpr
|
||||
trace[Elab.definition.wf] "folded discriminants {indentExpr eNew}"
|
||||
return Simp.Step.visit { expr := app.toExpr }
|
||||
-- First try to reduce matcher
|
||||
match (← reduceRecMatcher? e) with
|
||||
| some e' => return Simp.Step.done { expr := e' }
|
||||
|
|
@ -58,36 +126,63 @@ where
|
|||
| some r => return r
|
||||
| none => return Simp.Step.visit { expr := e }
|
||||
|
||||
private def tryToFoldLHS? (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (mvarId : MVarId) : MetaM (Option MVarId) :=
|
||||
withMVarContext mvarId do
|
||||
let target ← getMVarType' mvarId
|
||||
let some (_, lhs, rhs) := target.eq? | unreachable!
|
||||
let lhsNew ← tryToFoldWellFoundedFix info us fixedPrefix lhs
|
||||
if lhs == lhsNew then return none
|
||||
let targetNew ← mkEq lhsNew rhs
|
||||
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
|
||||
assignExprMVar mvarId mvarNew
|
||||
return mvarNew.mvarId!
|
||||
|
||||
/--
|
||||
Given a goal of the form `|- f.{us} a_1 ... a_n b_1 ... b_m = ...`, return `(us, #[a_1, ..., a_n])`
|
||||
where `f` is a constant named `declName`, and `n = info.fixedPrefixSize`.
|
||||
-/
|
||||
private def getFixedPrefix (declName : Name) (info : EqnInfo) (mvarId : MVarId) : MetaM (List Level × Array Expr) := withMVarContext mvarId do
|
||||
let target ← getMVarType' mvarId
|
||||
let some (_, lhs, rhs) := target.eq? | unreachable!
|
||||
let lhsArgs := lhs.getAppArgs
|
||||
if lhsArgs.size < info.fixedPrefixSize || !lhs.getAppFn matches .const .. then
|
||||
throwError "failed to generate equational theorem for '{declName}', unexpected number of arguments in the equation left-hand-side\n{mvarId}"
|
||||
let result := lhsArgs[:info.fixedPrefixSize]
|
||||
trace[Elab.definition.wf.eqns] "fixedPrefix: {result}"
|
||||
return (lhs.getAppFn.constLevels!, result)
|
||||
|
||||
private partial def mkProof (declName : Name) (info : EqnInfo) (type : Expr) : MetaM Expr := do
|
||||
trace[Elab.definition.wf.eqns] "proving: {type}"
|
||||
withNewMCtxDepth do
|
||||
let main ← mkFreshExprSyntheticOpaqueMVar type
|
||||
let (_, mvarId) ← intros main.mvarId!
|
||||
let (us, fixedPrefix) ← getFixedPrefix declName info mvarId
|
||||
let rec go (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
|
||||
if (← tryURefl mvarId) then
|
||||
return ()
|
||||
else if (← tryContradiction mvarId) then
|
||||
return ()
|
||||
else if let some mvarId ← simpMatchWF? info us fixedPrefix mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← simpIf? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← whnfReducibleLHS? mvarId then
|
||||
go mvarId
|
||||
else match (← simpTargetStar mvarId {}) with
|
||||
| TacticResultCNM.closed => return ()
|
||||
| TacticResultCNM.modified mvarId => go mvarId
|
||||
| TacticResultCNM.noChange =>
|
||||
if let some mvarIds ← casesOnStuckLHS? mvarId then
|
||||
mvarIds.forM go
|
||||
else if let some mvarIds ← splitTarget? mvarId then
|
||||
mvarIds.forM go
|
||||
else if let some mvarId ← tryToFoldLHS? info us fixedPrefix mvarId then
|
||||
go mvarId
|
||||
else
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
go (← rwFixEq (← deltaLHSUntilFix mvarId))
|
||||
instantiateMVars main
|
||||
where
|
||||
go (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
|
||||
if (← tryURefl mvarId) then
|
||||
return ()
|
||||
else if (← tryContradiction mvarId) then
|
||||
return ()
|
||||
else if let some mvarId ← simpMatchWF? mvarId info then
|
||||
go mvarId
|
||||
else if let some mvarId ← simpIf? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← whnfReducibleLHS? mvarId then
|
||||
go mvarId
|
||||
else match (← simpTargetStar mvarId {}) with
|
||||
| TacticResultCNM.closed => return ()
|
||||
| TacticResultCNM.modified mvarId => go mvarId
|
||||
| TacticResultCNM.noChange =>
|
||||
if let some mvarIds ← casesOnStuckLHS? mvarId then
|
||||
mvarIds.forM go
|
||||
else if let some mvarIds ← splitTarget? mvarId then
|
||||
mvarIds.forM go
|
||||
else
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
|
||||
def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
|
||||
withOptions (tactic.hygienic.set · false) do
|
||||
|
|
|
|||
66
tests/lean/run/nestedWF.lean
Normal file
66
tests/lean/run/nestedWF.lean
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
namespace Ex1
|
||||
|
||||
mutual
|
||||
def h (c : Nat) (x : Nat) := match g c x c c with
|
||||
| 0 => 1
|
||||
| r => r + 2
|
||||
def g (c : Nat) (t : Nat) (a b : Nat) : Nat := match t with
|
||||
| (n+1) => match g c n a b with
|
||||
| 0 => 0
|
||||
| m => match g c (n - m) a b with
|
||||
| 0 => 0
|
||||
| m + 1 => g c m a b
|
||||
| 0 => f c 0
|
||||
def f (c : Nat) (x : Nat) := match h c x with
|
||||
| 0 => 1
|
||||
| r => f c r
|
||||
end
|
||||
termination_by
|
||||
g x a b => 0
|
||||
f c x => 0
|
||||
h c x => 0
|
||||
decreasing_by sorry
|
||||
|
||||
attribute [simp] g
|
||||
attribute [simp] h
|
||||
attribute [simp] f
|
||||
|
||||
#check g._eq_1
|
||||
#check g._eq_2
|
||||
#check g._eq_3
|
||||
#check g._eq_4
|
||||
|
||||
#check h._eq_1
|
||||
|
||||
#check f._eq_1
|
||||
#check f._eq_2
|
||||
|
||||
end Ex1
|
||||
|
||||
namespace Ex2
|
||||
|
||||
def g (t : Nat) : Nat := match t with
|
||||
| (n+1) => match g n with
|
||||
| 0 => 0
|
||||
| m + 1 => match g (n - m) with
|
||||
| 0 => 0
|
||||
| m + 1 => g n
|
||||
| 0 => 0
|
||||
termination_by' sorry
|
||||
decreasing_by sorry
|
||||
|
||||
theorem ex1 : g 0 = 0 := by
|
||||
rw [g]
|
||||
|
||||
#check g._eq_1
|
||||
#check g._eq_2
|
||||
#check g._eq_3
|
||||
|
||||
theorem ex2 : g 0 = 0 := by
|
||||
unfold g
|
||||
simp
|
||||
|
||||
#check g._unfold
|
||||
|
||||
|
||||
end Ex2
|
||||
Loading…
Add table
Reference in a new issue