From 3418d6db8ebdec9a8d08316789220cf51336ba15 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Mon, 27 Jan 2025 15:00:55 +0100 Subject: [PATCH] fix: more robust equational theorems generation for partial_fixpoint (#6790) This PR fixes an issue with the generation of equational theorems from `partial_fixpoint` when case-splitting is necessary. Fixes #6786. --- src/Lean/Elab/PreDefinition/Eqns.lean | 109 ++++++++++++++++++ src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean | 63 +--------- .../PreDefinition/PartialFixpoint/Eqns.lean | 104 ++++++++--------- tests/lean/run/issue6786.lean | 101 ++++++++++++++++ 4 files changed, 264 insertions(+), 113 deletions(-) create mode 100644 tests/lean/run/issue6786.lean diff --git a/src/Lean/Elab/PreDefinition/Eqns.lean b/src/Lean/Elab/PreDefinition/Eqns.lean index be32f60c7d..a55513c6bf 100644 --- a/src/Lean/Elab/PreDefinition/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Eqns.lean @@ -308,6 +308,115 @@ def whnfReducibleLHS? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withCo def tryContradiction (mvarId : MVarId) : MetaM Bool := do mvarId.contradictionCore { genDiseq := true } +/-- +Returns the type of the unfold theorem, as the starting point for calculating the equational +types. +-/ +private def unfoldThmType (declName : Name) : MetaM Expr := do + if let some unfoldThm ← getUnfoldEqnFor? declName (nonRec := false) then + let info ← getConstInfo unfoldThm + pure info.type + else + let info ← getConstInfoDefn declName + let us := info.levelParams.map mkLevelParam + lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do + let type ← mkEq (mkAppN (Lean.mkConst declName us) xs) body + mkForallFVars xs type + +private def unfoldLHS (declName : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do + if let some unfoldThm ← getUnfoldEqnFor? declName (nonRec := false) then + -- Recursive definition: Use unfolding lemma + let mut mvarId := mvarId + let target ← mvarId.getType' + let some (_, lhs, rhs) := target.eq? | throwError "unfoldLHS: Unexpected target {target}" + unless lhs.isAppOf declName do throwError "unfoldLHS: Unexpected LHS {lhs}" + let h := mkAppN (.const unfoldThm lhs.getAppFn.constLevels!) lhs.getAppArgs + let some (_, _, lhsNew) := (← inferType h).eq? | unreachable! + let targetNew ← mkEq lhsNew rhs + let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew + mvarId.assign (← mkEqTrans h mvarNew) + return mvarNew.mvarId! + else + -- Else use delta reduction + deltaLHS mvarId + +private partial def mkEqnProof (declName : Name) (type : Expr) : MetaM Expr := do + trace[Elab.definition.eqns] "proving: {type}" + withNewMCtxDepth do + let main ← mkFreshExprSyntheticOpaqueMVar type + let (_, mvarId) ← main.mvarId!.intros + -- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make + -- the lemma ineligible for dsimp + unless ← withAtLeastTransparency .all (tryURefl mvarId) do + go (← unfoldLHS declName mvarId) + instantiateMVars main + where + /-- + The core loop of proving an equation. Assumes that the function call on the left-hand side has + already been unfolded, using whatever method applies to the current function definition strategy. + + Currently used for non-recursive functions and partial fixpoints; maybe later well-founded + recursion and structural recursion can and should use this too. + -/ + go (mvarId : MVarId) : MetaM Unit := do + trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}" + if ← withAtLeastTransparency .all (tryURefl mvarId) then + return () + else if (← tryContradiction mvarId) then + return () + else if let some mvarId ← simpMatch? mvarId then + go mvarId + else if let some mvarId ← simpIf? mvarId then + go mvarId + else if let some mvarId ← whnfReducibleLHS? mvarId then + go mvarId + else + let ctx ← Simp.mkContext (config := { dsimp := false }) + match (← simpTargetStar mvarId ctx (simprocs := {})).1 with + | TacticResultCNM.closed => return () + | TacticResultCNM.modified mvarId => go mvarId + | TacticResultCNM.noChange => + if let some mvarIds ← casesOnStuckLHS? mvarId then + mvarIds.forM go + else if let some mvarIds ← splitTarget? mvarId then + mvarIds.forM go + else + throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" + + +/-- +Generate equations for `declName`. + +This unfolds the function application on the LHS (using an unfold theorem, if present, or else by +delta-reduction), calculates the types for the equational theorems using `mkEqnTypes`, and then +proves them using `mkEqnProof`. + +This is currently used for non-recursive functions and for functions defined by partial_fixpoint. +-/ +def mkEqns (declName : Name) : MetaM (Array Name) := do + let info ← getConstInfoDefn declName + let us := info.levelParams.map mkLevelParam + withOptions (tactic.hygienic.set · false) do + let target ← unfoldThmType declName + let eqnTypes ← withNewMCtxDepth <| + forallTelescope (cleanupAnnotations := true) target fun xs target => do + let goal ← mkFreshExprSyntheticOpaqueMVar target + withReducible do + mkEqnTypes #[] goal.mvarId! + let mut thmNames := #[] + for h : i in [: eqnTypes.size] do + let type := eqnTypes[i] + trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}" + let name := (Name.str declName eqnThmSuffixBase).appendIndexAfter (i+1) + thmNames := thmNames.push name + let value ← mkEqnProof declName type + let (type, value) ← removeUnusedEqnHypotheses type value + addDecl <| Declaration.thmDecl { + name, type, value + levelParams := info.levelParams + } + return thmNames + /-- Auxiliary method for `mkUnfoldEq`. The structure is based on `mkEqnTypes`. `mvarId` is the goal to be proved. It is a goal of the form diff --git a/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean b/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean index 91c5f8f151..386cfdb26b 100644 --- a/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean @@ -33,71 +33,12 @@ private def mkSimpleEqThm (declName : Name) (suffix := Name.mkSimple unfoldThmSu else return none -private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do - trace[Elab.definition.eqns] "proving: {type}" - withNewMCtxDepth do - let main ← mkFreshExprSyntheticOpaqueMVar type - let (_, mvarId) ← main.mvarId!.intros - let rec go (mvarId : MVarId) : MetaM Unit := do - trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}" - if ← withAtLeastTransparency .all (tryURefl mvarId) then - return () - else if (← tryContradiction mvarId) then - return () - else if let some mvarId ← simpMatch? mvarId then - go mvarId - else if let some mvarId ← simpIf? mvarId then - go mvarId - else if let some mvarId ← whnfReducibleLHS? mvarId then - go mvarId - else - let ctx ← Simp.mkContext (config := { dsimp := false }) - match (← simpTargetStar mvarId ctx (simprocs := {})).1 with - | TacticResultCNM.closed => return () - | TacticResultCNM.modified mvarId => go mvarId - | TacticResultCNM.noChange => - if let some mvarIds ← casesOnStuckLHS? mvarId then - mvarIds.forM go - else if let some mvarIds ← splitTarget? mvarId then - mvarIds.forM go - else - throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" - - -- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make - -- the lemma ineligible for dsimp - unless ← withAtLeastTransparency .all (tryURefl mvarId) do - go (← deltaLHS mvarId) - instantiateMVars main - -def mkEqns (declName : Name) (info : DefinitionVal) : MetaM (Array Name) := - withOptions (tactic.hygienic.set · false) do - let baseName := declName - let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do - let us := info.levelParams.map mkLevelParam - let target ← mkEq (mkAppN (Lean.mkConst declName us) xs) body - let goal ← mkFreshExprSyntheticOpaqueMVar target - withReducible do - mkEqnTypes #[] goal.mvarId! - let mut thmNames := #[] - for h : i in [: eqnTypes.size] do - let type := eqnTypes[i] - trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}" - let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1) - thmNames := thmNames.push name - let value ← mkProof declName type - let (type, value) ← removeUnusedEqnHypotheses type value - addDecl <| Declaration.thmDecl { - name, type, value - levelParams := info.levelParams - } - return thmNames - def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do if (← isRecursiveDefinition declName) then return none - if let some (.defnInfo info) := (← getEnv).find? declName then + if (← getEnv).contains declName then if backward.eqns.nonrecursive.get (← getOptions) then - mkEqns declName info + mkEqns declName else let o ← mkSimpleEqThm declName return o.map (#[·]) diff --git a/src/Lean/Elab/PreDefinition/PartialFixpoint/Eqns.lean b/src/Lean/Elab/PreDefinition/PartialFixpoint/Eqns.lean index 32b0790b6a..edd82ac2a5 100644 --- a/src/Lean/Elab/PreDefinition/PartialFixpoint/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/PartialFixpoint/Eqns.lean @@ -23,6 +23,18 @@ structure EqnInfo extends EqnInfoCore where fixedPrefixSize : Nat deriving Inhabited +builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension + +def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do + preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName + unless preDefs.all fun p => p.kind.isTheorem do + unless (← preDefs.allM fun p => isProp p.type) do + let declNames := preDefs.map (·.declName) + modifyEnv fun env => + preDefs.foldl (init := env) fun env preDef => + eqnInfoExt.insert env preDef.declName { preDef with + declNames, declNameNonRec, fixedPrefixSize } + private def deltaLHSUntilFix (declName declNameNonRec : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do let target ← mvarId.getType' let some (_, lhs, rhs) := target.eq? | throwTacticEx `deltaLHSUntilFix mvarId "equality expected" @@ -53,62 +65,50 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do mvarId.assign (← mkEqTrans h mvarNew) return mvarNew.mvarId! -private partial def mkProof (declName : Name) (declNameNonRec : Name) (type : Expr) : MetaM Expr := do - trace[Elab.definition.partialFixpoint] "proving: {type}" - withNewMCtxDepth do - let main ← mkFreshExprSyntheticOpaqueMVar type - let (_, mvarId) ← main.mvarId!.intros - let mvarId ← deltaLHSUntilFix declName declNameNonRec mvarId - let mvarId ← rwFixEq mvarId - if ← withAtLeastTransparency .all (tryURefl mvarId) then - instantiateMVars main - else - throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" - -def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) := +/-- Generate the "unfold" lemma for `declName`. -/ +def mkUnfoldEq (declName : Name) (info : EqnInfo) : MetaM Name := withLCtx {} {} do withOptions (tactic.hygienic.set · false) do - let baseName := declName - let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do - let us := info.levelParams.map mkLevelParam - let target ← mkEq (mkAppN (Lean.mkConst declName us) xs) body - let goal ← mkFreshExprSyntheticOpaqueMVar target - withReducible do - mkEqnTypes info.declNames goal.mvarId! - let mut thmNames := #[] - for h : i in [: eqnTypes.size] do - let type := eqnTypes[i] - trace[Elab.definition.partialFixpoint] "{eqnTypes[i]}" - let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1) - thmNames := thmNames.push name - let value ← mkProof declName info.declNameNonRec type - let (type, value) ← removeUnusedEqnHypotheses type value - addDecl <| Declaration.thmDecl { - name, type, value - levelParams := info.levelParams - } - return thmNames - -builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension - -def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do - preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName - unless preDefs.all fun p => p.kind.isTheorem do - unless (← preDefs.allM fun p => isProp p.type) do - let declNames := preDefs.map (·.declName) - modifyEnv fun env => - preDefs.foldl (init := env) fun env preDef => - eqnInfoExt.insert env preDef.declName { preDef with - declNames, declNameNonRec, fixedPrefixSize } - -def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do - if let some info := eqnInfoExt.find? (← getEnv) declName then - mkEqns declName info - else - return none + let baseName := declName + lambdaTelescope info.value fun xs body => do + let us := info.levelParams.map mkLevelParam + let type ← mkEq (mkAppN (Lean.mkConst declName us) xs) body + let goal ← withNewMCtxDepth do + try + let goal ← mkFreshExprSyntheticOpaqueMVar type + let mvarId := goal.mvarId! + trace[Elab.definition.partialFixpoint] "mkUnfoldEq start:{mvarId}" + let mvarId ← deltaLHSUntilFix declName info.declNameNonRec mvarId + trace[Elab.definition.partialFixpoint] "mkUnfoldEq after deltaLHS:{mvarId}" + let mvarId ← rwFixEq mvarId + trace[Elab.definition.partialFixpoint] "mkUnfoldEq after rwFixEq:{mvarId}" + withAtLeastTransparency .all <| + withOptions (smartUnfolding.set · false) <| + mvarId.refl + trace[Elab.definition.partialFixpoint] "mkUnfoldEq rfl succeeded" + instantiateMVars goal + catch e => + throwError "failed to generate unfold theorem for '{declName}':\n{e.toMessageData}" + let type ← mkForallFVars xs type + let value ← mkLambdaFVars xs goal + let name := Name.str baseName unfoldThmSuffix + addDecl <| Declaration.thmDecl { + name, type, value + levelParams := info.levelParams + } + return name def getUnfoldFor? (declName : Name) : MetaM (Option Name) := do + let name := Name.str declName unfoldThmSuffix let env ← getEnv - Eqns.getUnfoldFor? declName fun _ => eqnInfoExt.find? env declName |>.map (·.toEqnInfoCore) + if env.contains name then return name + let some info := eqnInfoExt.find? env declName | return none + return some (← mkUnfoldEq declName info) + +def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do + if let some _ := eqnInfoExt.find? (← getEnv) declName then + mkEqns declName + else + return none builtin_initialize registerGetEqnsFn getEqnsFor? diff --git a/tests/lean/run/issue6786.lean b/tests/lean/run/issue6786.lean new file mode 100644 index 0000000000..98d8d760a5 --- /dev/null +++ b/tests/lean/run/issue6786.lean @@ -0,0 +1,101 @@ +def find42 : Nat → Bool + | 42 => true + | n => find42 (n + 1) +partial_fixpoint + +/-- +info: find42.eq_def (x✝ : Nat) : + find42 x✝ = + match x✝ with + | 42 => true + | n => find42 (n + 1) +-/ +#guard_msgs in +#check find42.eq_def + +/-- +info: equations: +theorem find42.eq_1 : find42 42 = true +theorem find42.eq_2 : ∀ (x : Nat), (x = 42 → False) → find42 x = find42 (x + 1) +-/ +#guard_msgs in +#print equations find42 + +mutual +def find99 : Nat → Bool + | 99 => true + | n => find23 (n + 1) +partial_fixpoint +def find23 : Nat → Bool + | 23 => true + | n => find99 (n + 1) +partial_fixpoint +end + +/-- +info: find99.eq_def (x✝ : Nat) : + find99 x✝ = + match x✝ with + | 99 => true + | n => find23 (n + 1) +-/ +#guard_msgs in +#check find99.eq_def + +/-- +info: find23.eq_def (x✝ : Nat) : + find23 x✝ = + match x✝ with + | 23 => true + | n => find99 (n + 1) +-/ +#guard_msgs in +#check find23.eq_def + +/-- +info: equations: +theorem find99.eq_1 : find99 99 = true +theorem find99.eq_2 : ∀ (x : Nat), (x = 99 → False) → find99 x = find23 (x + 1) +-/ +#guard_msgs in +#print equations find99 + +/-- +info: equations: +theorem find23.eq_1 : find23 23 = true +theorem find23.eq_2 : ∀ (x : Nat), (x = 23 → False) → find23 x = find99 (x + 1) +-/ +#guard_msgs in +#print equations find23 + + +mutual +def g (i j : Nat) : Nat := + if i < 5 then 0 else + match j with + | Nat.zero => 1 + | Nat.succ j => h i j +partial_fixpoint + +def h (i j : Nat) : Nat := + match j with + | 0 => g i 0 + | Nat.succ j => g i j +partial_fixpoint +end + +/-- +info: equations: +theorem g.eq_1 : ∀ (i : Nat), g i Nat.zero = if i < 5 then 0 else 1 +theorem g.eq_2 : ∀ (i j_2 : Nat), g i j_2.succ = if i < 5 then 0 else h i j_2 +-/ +#guard_msgs in +#print equations g + +/-- +info: equations: +theorem h.eq_1 : ∀ (i : Nat), h i 0 = g i 0 +theorem h.eq_2 : ∀ (i j_2 : Nat), h i j_2.succ = g i j_2 +-/ +#guard_msgs in +#print equations h