From 096e4eb6d0f56ba395bb658bea87518a8baa4b57 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 31 Mar 2022 17:04:06 -0700 Subject: [PATCH] fix: equation generation for nested recursive definitions The issue was raised on Zulip. The issue is triggered in declarations containing overlapping patterns and nested recursive definitions occurring as the discriminant of `match`-expressions. Recall that Lean 4 generates conditional equations for declarations containing overlapping patterns. To address the issue we had to "fold" `WellFounded.fix` applications back as recursive applications of the functions being defined. The new test exposes the issue. --- src/Lean/Elab/PreDefinition/WF/Eqns.lean | 155 ++++++++++++++++++----- tests/lean/run/nestedWF.lean | 66 ++++++++++ 2 files changed, 191 insertions(+), 30 deletions(-) create mode 100644 tests/lean/run/nestedWF.lean diff --git a/src/Lean/Elab/PreDefinition/WF/Eqns.lean b/src/Lean/Elab/PreDefinition/WF/Eqns.lean index c7e7342100..38c00903a1 100644 --- a/src/Lean/Elab/PreDefinition/WF/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/WF/Eqns.lean @@ -39,17 +39,85 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := withMVarContext mvarId d private def hasWellFoundedFix (e : Expr) : Bool := Option.isSome <| e.find? (·.isConstOf ``WellFounded.fix) -def simpMatchWF? (mvarId : MVarId) (info : EqnInfo) : MetaM (Option MVarId) := withMVarContext mvarId do - let target ← instantiateMVars (← getMVarType mvarId) - let targetNew ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre }) - let mvarIdNew ← applySimpResultToTarget mvarId target targetNew - if mvarId != mvarIdNew then return some mvarIdNew else return none +/-- + Helper function for decoding the packed argument for a `WellFounded.fix` application. + Recall that we use `PSum` and `PSigma` for packing the arguments of mutually recursive nary functions. +-/ +private partial def decodePackedArg? (info : EqnInfo) (e : Expr) : Option (Name × Array Expr) := OptionM.run do + if info.declNames.size == 1 then + let args := decodePSigma e #[] + return (info.declNames[0], args) + else + decodePSum? e 0 +where + decodePSum? (e : Expr) (i : Nat) : Option (Name × Array Expr) := OptionM.run do + if e.isAppOfArity ``PSum.inl 3 then + decodePSum? e.appArg! i + else if e.isAppOfArity ``PSum.inr 3 then + decodePSum? e.appArg! (i+1) + else + guard (i < info.declNames.size) + return (info.declNames[i], decodePSigma e #[]) + + decodePSigma (e : Expr) (acc : Array Expr) : Array Expr := + /- TODO: check arity of the given function. If it takes a PSigma as the last argument, + this function will produce incorrect results. -/ + if e.isAppOfArity ``PSigma.mk 4 then + decodePSigma e.appArg! (acc.push e.appFn!.appArg!) + else + acc.push e + +/-- + Try to fold `WellFounded.fix` applications that represent recursive applications of the functions in `info.declNames`. + We need that to make sure `simpMatchWF?` succeeds at goals such as + ```lean + ... + h : g x = 0 + ... + |- (match (WellFounded.fix ...) with | ...) = ... + ``` + where `WellFounded.fix ...` can be folded back to `g x`. +-/ +private def tryToFoldWellFoundedFix (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (e : Expr) : MetaM Expr := do + if hasWellFoundedFix e then + transform e (pre := pre) + else + return e +where + pre (e : Expr) : MetaM TransformStep := do + let e' := e.headBeta + if e'.isAppOf ``WellFounded.fix && e'.getAppNumArgs >= 6 then + let args := e'.getAppArgs + let packedArg := args[5] + let extraArgs := args[6:] + if let some (declName, args) := decodePackedArg? info packedArg then + let candidate := mkAppN (mkAppN (mkAppN (mkConst declName us) fixedPrefix) args) extraArgs + trace[Elab.definition.wf] "found nested WF at discr {candidate}" + if (← withDefault <| isDefEq candidate e) then + return TransformStep.visit candidate + return TransformStep.visit e + +/-- + Simplify `match`-expressions when trying to prove equation theorems for a recursive declaration defined using well-founded recursion. + It is similar to `simpMatch?`, but is also tries to fold `WellFounded.fix` applications occurring in discriminants. + See comment at `tryToFoldWellFoundedFix`. +-/ +def simpMatchWF? (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (mvarId : MVarId) : MetaM (Option MVarId) := + withMVarContext mvarId do + let target ← instantiateMVars (← getMVarType mvarId) + let targetNew ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre }) + let mvarIdNew ← applySimpResultToTarget mvarId target targetNew + if mvarId != mvarIdNew then return some mvarIdNew else return none where pre (e : Expr) : SimpM Simp.Step := do let some app ← matchMatcherApp? e | return Simp.Step.visit { expr := e } if app.discrs.any hasWellFoundedFix then - -- TODO: try to fold `WellFounded.fix` occurrences in the discriminant - pure () + let discrsNew ← app.discrs.mapM (tryToFoldWellFoundedFix info us fixedPrefix ·) + if discrsNew != app.discrs then + let app := { app with discrs := discrsNew } + let eNew := app.toExpr + trace[Elab.definition.wf] "folded discriminants {indentExpr eNew}" + return Simp.Step.visit { expr := app.toExpr } -- First try to reduce matcher match (← reduceRecMatcher? e) with | some e' => return Simp.Step.done { expr := e' } @@ -58,36 +126,63 @@ where | some r => return r | none => return Simp.Step.visit { expr := e } +private def tryToFoldLHS? (info : EqnInfo) (us : List Level) (fixedPrefix : Array Expr) (mvarId : MVarId) : MetaM (Option MVarId) := + withMVarContext mvarId do + let target ← getMVarType' mvarId + let some (_, lhs, rhs) := target.eq? | unreachable! + let lhsNew ← tryToFoldWellFoundedFix info us fixedPrefix lhs + if lhs == lhsNew then return none + let targetNew ← mkEq lhsNew rhs + let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew + assignExprMVar mvarId mvarNew + return mvarNew.mvarId! + +/-- + Given a goal of the form `|- f.{us} a_1 ... a_n b_1 ... b_m = ...`, return `(us, #[a_1, ..., a_n])` + where `f` is a constant named `declName`, and `n = info.fixedPrefixSize`. +-/ +private def getFixedPrefix (declName : Name) (info : EqnInfo) (mvarId : MVarId) : MetaM (List Level × Array Expr) := withMVarContext mvarId do + let target ← getMVarType' mvarId + let some (_, lhs, rhs) := target.eq? | unreachable! + let lhsArgs := lhs.getAppArgs + if lhsArgs.size < info.fixedPrefixSize || !lhs.getAppFn matches .const .. then + throwError "failed to generate equational theorem for '{declName}', unexpected number of arguments in the equation left-hand-side\n{mvarId}" + let result := lhsArgs[:info.fixedPrefixSize] + trace[Elab.definition.wf.eqns] "fixedPrefix: {result}" + return (lhs.getAppFn.constLevels!, result) + private partial def mkProof (declName : Name) (info : EqnInfo) (type : Expr) : MetaM Expr := do trace[Elab.definition.wf.eqns] "proving: {type}" withNewMCtxDepth do let main ← mkFreshExprSyntheticOpaqueMVar type let (_, mvarId) ← intros main.mvarId! + let (us, fixedPrefix) ← getFixedPrefix declName info mvarId + let rec go (mvarId : MVarId) : MetaM Unit := do + trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}" + if (← tryURefl mvarId) then + return () + else if (← tryContradiction mvarId) then + return () + else if let some mvarId ← simpMatchWF? info us fixedPrefix mvarId then + go mvarId + else if let some mvarId ← simpIf? mvarId then + go mvarId + else if let some mvarId ← whnfReducibleLHS? mvarId then + go mvarId + else match (← simpTargetStar mvarId {}) with + | TacticResultCNM.closed => return () + | TacticResultCNM.modified mvarId => go mvarId + | TacticResultCNM.noChange => + if let some mvarIds ← casesOnStuckLHS? mvarId then + mvarIds.forM go + else if let some mvarIds ← splitTarget? mvarId then + mvarIds.forM go + else if let some mvarId ← tryToFoldLHS? info us fixedPrefix mvarId then + go mvarId + else + throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" go (← rwFixEq (← deltaLHSUntilFix mvarId)) instantiateMVars main -where - go (mvarId : MVarId) : MetaM Unit := do - trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}" - if (← tryURefl mvarId) then - return () - else if (← tryContradiction mvarId) then - return () - else if let some mvarId ← simpMatchWF? mvarId info then - go mvarId - else if let some mvarId ← simpIf? mvarId then - go mvarId - else if let some mvarId ← whnfReducibleLHS? mvarId then - go mvarId - else match (← simpTargetStar mvarId {}) with - | TacticResultCNM.closed => return () - | TacticResultCNM.modified mvarId => go mvarId - | TacticResultCNM.noChange => - if let some mvarIds ← casesOnStuckLHS? mvarId then - mvarIds.forM go - else if let some mvarIds ← splitTarget? mvarId then - mvarIds.forM go - else - throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) := withOptions (tactic.hygienic.set · false) do diff --git a/tests/lean/run/nestedWF.lean b/tests/lean/run/nestedWF.lean new file mode 100644 index 0000000000..feb5332d24 --- /dev/null +++ b/tests/lean/run/nestedWF.lean @@ -0,0 +1,66 @@ +namespace Ex1 + +mutual +def h (c : Nat) (x : Nat) := match g c x c c with + | 0 => 1 + | r => r + 2 +def g (c : Nat) (t : Nat) (a b : Nat) : Nat := match t with + | (n+1) => match g c n a b with + | 0 => 0 + | m => match g c (n - m) a b with + | 0 => 0 + | m + 1 => g c m a b + | 0 => f c 0 +def f (c : Nat) (x : Nat) := match h c x with + | 0 => 1 + | r => f c r +end +termination_by + g x a b => 0 + f c x => 0 + h c x => 0 +decreasing_by sorry + +attribute [simp] g +attribute [simp] h +attribute [simp] f + +#check g._eq_1 +#check g._eq_2 +#check g._eq_3 +#check g._eq_4 + +#check h._eq_1 + +#check f._eq_1 +#check f._eq_2 + +end Ex1 + +namespace Ex2 + +def g (t : Nat) : Nat := match t with + | (n+1) => match g n with + | 0 => 0 + | m + 1 => match g (n - m) with + | 0 => 0 + | m + 1 => g n + | 0 => 0 +termination_by' sorry +decreasing_by sorry + +theorem ex1 : g 0 = 0 := by + rw [g] + +#check g._eq_1 +#check g._eq_2 +#check g._eq_3 + +theorem ex2 : g 0 = 0 := by + unfold g + simp + +#check g._unfold + + +end Ex2