fix: equation generation for nested recursive definitions

The issue was raised on Zulip. The issue is triggered in
declarations containing overlapping patterns and nested recursive
definitions occurring as the discriminant of `match`-expressions.
Recall that Lean 4 generates conditional equations for declarations
containing overlapping patterns.
To address the issue we had to "fold" `WellFounded.fix` applications
back as recursive applications of the functions being defined.

The new test exposes the issue.
This commit is contained in:
Leonardo de Moura 2022-03-31 17:04:06 -07:00
parent 6652d2665d
commit 096e4eb6d0
2 changed files with 191 additions and 30 deletions

View file

@ -39,17 +39,85 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := withMVarContext mvarId d
private def hasWellFoundedFix (e : Expr) : Bool :=
Option.isSome <| e.find? (·.isConstOf ``WellFounded.fix)
def simpMatchWF? (mvarId : MVarId) (info : EqnInfo) : MetaM (Option MVarId) := withMVarContext mvarId do
let target ← instantiateMVars (← getMVarType mvarId)
let targetNew ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre })
let mvarIdNew ← applySimpResultToTarget mvarId target targetNew
if mvarId != mvarIdNew then return some mvarIdNew else return none
/--
Helper function for decoding the packed argument for a `WellFounded.fix` application.
Recall that we use `PSum` and `PSigma` for packing the arguments of mutually recursive nary functions.
-/
private partial def decodePackedArg? (info : EqnInfo) (e : Expr) : Option (Name × Array Expr) := OptionM.run do
if info.declNames.size == 1 then
let args := decodePSigma e #[]
return (info.declNames[0], args)
else
decodePSum? e 0
where
decodePSum? (e : Expr) (i : Nat) : Option (Name × Array Expr) := OptionM.run do
if e.isAppOfArity ``PSum.inl 3 then
decodePSum? e.appArg! i
else if e.isAppOfArity ``PSum.inr 3 then
decodePSum? e.appArg! (i+1)
else
guard (i < info.declNames.size)
return (info.declNames[i], decodePSigma e #[])
decodePSigma (e : Expr) (acc : Array Expr) : Array Expr :=
/- TODO: check arity of the given function. If it takes a PSigma as the last argument,
this function will produce incorrect results. -/
if e.isAppOfArity ``PSigma.mk 4 then
decodePSigma e.appArg! (acc.push e.appFn!.appArg!)
else
acc.push e
/--
Try to fold `WellFounded.fix` applications that represent recursive applications of the functions in `info.declNames`.
We need that to make sure `simpMatchWF?` succeeds at goals such as
```lean
...
h : g x = 0
...
|- (match (WellFounded.fix ...) with | ...) = ...
```
where `WellFounded.fix ...` can be folded back to `g x`.
-/
private def tryToFoldWellFoundedFix (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (e : Expr) : MetaM Expr := do
if hasWellFoundedFix e then
transform e (pre := pre)
else
return e
where
pre (e : Expr) : MetaM TransformStep := do
let e' := e.headBeta
if e'.isAppOf ``WellFounded.fix && e'.getAppNumArgs >= 6 then
let args := e'.getAppArgs
let packedArg := args[5]
let extraArgs := args[6:]
if let some (declName, args) := decodePackedArg? info packedArg then
let candidate := mkAppN (mkAppN (mkAppN (mkConst declName us) fixedPrefix) args) extraArgs
trace[Elab.definition.wf] "found nested WF at discr {candidate}"
if (← withDefault <| isDefEq candidate e) then
return TransformStep.visit candidate
return TransformStep.visit e
/--
Simplify `match`-expressions when trying to prove equation theorems for a recursive declaration defined using well-founded recursion.
It is similar to `simpMatch?`, but is also tries to fold `WellFounded.fix` applications occurring in discriminants.
See comment at `tryToFoldWellFoundedFix`.
-/
def simpMatchWF? (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (mvarId : MVarId) : MetaM (Option MVarId) :=
withMVarContext mvarId do
let target ← instantiateMVars (← getMVarType mvarId)
let targetNew ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre })
let mvarIdNew ← applySimpResultToTarget mvarId target targetNew
if mvarId != mvarIdNew then return some mvarIdNew else return none
where
pre (e : Expr) : SimpM Simp.Step := do
let some app ← matchMatcherApp? e | return Simp.Step.visit { expr := e }
if app.discrs.any hasWellFoundedFix then
-- TODO: try to fold `WellFounded.fix` occurrences in the discriminant
pure ()
let discrsNew ← app.discrs.mapM (tryToFoldWellFoundedFix info us fixedPrefix ·)
if discrsNew != app.discrs then
let app := { app with discrs := discrsNew }
let eNew := app.toExpr
trace[Elab.definition.wf] "folded discriminants {indentExpr eNew}"
return Simp.Step.visit { expr := app.toExpr }
-- First try to reduce matcher
match (← reduceRecMatcher? e) with
| some e' => return Simp.Step.done { expr := e' }
@ -58,36 +126,63 @@ where
| some r => return r
| none => return Simp.Step.visit { expr := e }
private def tryToFoldLHS? (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (mvarId : MVarId) : MetaM (Option MVarId) :=
withMVarContext mvarId do
let target ← getMVarType' mvarId
let some (_, lhs, rhs) := target.eq? | unreachable!
let lhsNew ← tryToFoldWellFoundedFix info us fixedPrefix lhs
if lhs == lhsNew then return none
let targetNew ← mkEq lhsNew rhs
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
assignExprMVar mvarId mvarNew
return mvarNew.mvarId!
/--
Given a goal of the form `|- f.{us} a_1 ... a_n b_1 ... b_m = ...`, return `(us, #[a_1, ..., a_n])`
where `f` is a constant named `declName`, and `n = info.fixedPrefixSize`.
-/
private def getFixedPrefix (declName : Name) (info : EqnInfo) (mvarId : MVarId) : MetaM (List Level × Array Expr) := withMVarContext mvarId do
let target ← getMVarType' mvarId
let some (_, lhs, rhs) := target.eq? | unreachable!
let lhsArgs := lhs.getAppArgs
if lhsArgs.size < info.fixedPrefixSize || !lhs.getAppFn matches .const .. then
throwError "failed to generate equational theorem for '{declName}', unexpected number of arguments in the equation left-hand-side\n{mvarId}"
let result := lhsArgs[:info.fixedPrefixSize]
trace[Elab.definition.wf.eqns] "fixedPrefix: {result}"
return (lhs.getAppFn.constLevels!, result)
private partial def mkProof (declName : Name) (info : EqnInfo) (type : Expr) : MetaM Expr := do
trace[Elab.definition.wf.eqns] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← intros main.mvarId!
let (us, fixedPrefix) ← getFixedPrefix declName info mvarId
let rec go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
if (← tryURefl mvarId) then
return ()
else if (← tryContradiction mvarId) then
return ()
else if let some mvarId ← simpMatchWF? info us fixedPrefix mvarId then
go mvarId
else if let some mvarId ← simpIf? mvarId then
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
go mvarId
else match (← simpTargetStar mvarId {}) with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
mvarIds.forM go
else if let some mvarIds ← splitTarget? mvarId then
mvarIds.forM go
else if let some mvarId ← tryToFoldLHS? info us fixedPrefix mvarId then
go mvarId
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
go (← rwFixEq (← deltaLHSUntilFix mvarId))
instantiateMVars main
where
go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
if (← tryURefl mvarId) then
return ()
else if (← tryContradiction mvarId) then
return ()
else if let some mvarId ← simpMatchWF? mvarId info then
go mvarId
else if let some mvarId ← simpIf? mvarId then
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
go mvarId
else match (← simpTargetStar mvarId {}) with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
mvarIds.forM go
else if let some mvarIds ← splitTarget? mvarId then
mvarIds.forM go
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
withOptions (tactic.hygienic.set · false) do

View file

@ -0,0 +1,66 @@
namespace Ex1
mutual
def h (c : Nat) (x : Nat) := match g c x c c with
| 0 => 1
| r => r + 2
def g (c : Nat) (t : Nat) (a b : Nat) : Nat := match t with
| (n+1) => match g c n a b with
| 0 => 0
| m => match g c (n - m) a b with
| 0 => 0
| m + 1 => g c m a b
| 0 => f c 0
def f (c : Nat) (x : Nat) := match h c x with
| 0 => 1
| r => f c r
end
termination_by
g x a b => 0
f c x => 0
h c x => 0
decreasing_by sorry
attribute [simp] g
attribute [simp] h
attribute [simp] f
#check g._eq_1
#check g._eq_2
#check g._eq_3
#check g._eq_4
#check h._eq_1
#check f._eq_1
#check f._eq_2
end Ex1
namespace Ex2
def g (t : Nat) : Nat := match t with
| (n+1) => match g n with
| 0 => 0
| m + 1 => match g (n - m) with
| 0 => 0
| m + 1 => g n
| 0 => 0
termination_by' sorry
decreasing_by sorry
theorem ex1 : g 0 = 0 := by
rw [g]
#check g._eq_1
#check g._eq_2
#check g._eq_3
theorem ex2 : g 0 = 0 := by
unfold g
simp
#check g._unfold
end Ex2