fix: more robust equational theorems generation for partial_fixpoint (#6790)

This PR fixes an issue with the generation of equational theorems from
`partial_fixpoint` when case-splitting is necessary. Fixes #6786.
This commit is contained in:
Joachim Breitner 2025-01-27 15:00:55 +01:00 committed by GitHub
parent 3aea0fd810
commit 3418d6db8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 264 additions and 113 deletions

View file

@ -308,6 +308,115 @@ def whnfReducibleLHS? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withCo
def tryContradiction (mvarId : MVarId) : MetaM Bool := do
mvarId.contradictionCore { genDiseq := true }
/--
Returns the type of the unfold theorem, as the starting point for calculating the equational
types.
-/
private def unfoldThmType (declName : Name) : MetaM Expr := do
if let some unfoldThm ← getUnfoldEqnFor? declName (nonRec := false) then
let info ← getConstInfo unfoldThm
pure info.type
else
let info ← getConstInfoDefn declName
let us := info.levelParams.map mkLevelParam
lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let type ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
mkForallFVars xs type
private def unfoldLHS (declName : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
if let some unfoldThm ← getUnfoldEqnFor? declName (nonRec := false) then
-- Recursive definition: Use unfolding lemma
let mut mvarId := mvarId
let target ← mvarId.getType'
let some (_, lhs, rhs) := target.eq? | throwError "unfoldLHS: Unexpected target {target}"
unless lhs.isAppOf declName do throwError "unfoldLHS: Unexpected LHS {lhs}"
let h := mkAppN (.const unfoldThm lhs.getAppFn.constLevels!) lhs.getAppArgs
let some (_, _, lhsNew) := (← inferType h).eq? | unreachable!
let targetNew ← mkEq lhsNew rhs
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
mvarId.assign (← mkEqTrans h mvarNew)
return mvarNew.mvarId!
else
-- Else use delta reduction
deltaLHS mvarId
private partial def mkEqnProof (declName : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.eqns] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← main.mvarId!.intros
-- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make
-- the lemma ineligible for dsimp
unless ← withAtLeastTransparency .all (tryURefl mvarId) do
go (← unfoldLHS declName mvarId)
instantiateMVars main
where
/--
The core loop of proving an equation. Assumes that the function call on the left-hand side has
already been unfolded, using whatever method applies to the current function definition strategy.
Currently used for non-recursive functions and partial fixpoints; maybe later well-founded
recursion and structural recursion can and should use this too.
-/
go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}"
if ← withAtLeastTransparency .all (tryURefl mvarId) then
return ()
else if (← tryContradiction mvarId) then
return ()
else if let some mvarId ← simpMatch? mvarId then
go mvarId
else if let some mvarId ← simpIf? mvarId then
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
go mvarId
else
let ctx ← Simp.mkContext (config := { dsimp := false })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
mvarIds.forM go
else if let some mvarIds ← splitTarget? mvarId then
mvarIds.forM go
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
/--
Generate equations for `declName`.
This unfolds the function application on the LHS (using an unfold theorem, if present, or else by
delta-reduction), calculates the types for the equational theorems using `mkEqnTypes`, and then
proves them using `mkEqnProof`.
This is currently used for non-recursive functions and for functions defined by partial_fixpoint.
-/
def mkEqns (declName : Name) : MetaM (Array Name) := do
let info ← getConstInfoDefn declName
let us := info.levelParams.map mkLevelParam
withOptions (tactic.hygienic.set · false) do
let target ← unfoldThmType declName
let eqnTypes ← withNewMCtxDepth <|
forallTelescope (cleanupAnnotations := true) target fun xs target => do
let goal ← mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes #[] goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}"
let name := (Name.str declName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkEqnProof declName type
let (type, value) ← removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
/--
Auxiliary method for `mkUnfoldEq`. The structure is based on `mkEqnTypes`.
`mvarId` is the goal to be proved. It is a goal of the form

View file

@ -33,71 +33,12 @@ private def mkSimpleEqThm (declName : Name) (suffix := Name.mkSimple unfoldThmSu
else
return none
private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.eqns] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← main.mvarId!.intros
let rec go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}"
if ← withAtLeastTransparency .all (tryURefl mvarId) then
return ()
else if (← tryContradiction mvarId) then
return ()
else if let some mvarId ← simpMatch? mvarId then
go mvarId
else if let some mvarId ← simpIf? mvarId then
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
go mvarId
else
let ctx ← Simp.mkContext (config := { dsimp := false })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
mvarIds.forM go
else if let some mvarIds ← splitTarget? mvarId then
mvarIds.forM go
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
-- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make
-- the lemma ineligible for dsimp
unless ← withAtLeastTransparency .all (tryURefl mvarId) do
go (← deltaLHS mvarId)
instantiateMVars main
def mkEqns (declName : Name) (info : DefinitionVal) : MetaM (Array Name) :=
withOptions (tactic.hygienic.set · false) do
let baseName := declName
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let target ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
let goal ← mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes #[] goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}"
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkProof declName type
let (type, value) ← removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if (← isRecursiveDefinition declName) then
return none
if let some (.defnInfo info) := (← getEnv).find? declName then
if (← getEnv).contains declName then
if backward.eqns.nonrecursive.get (← getOptions) then
mkEqns declName info
mkEqns declName
else
let o ← mkSimpleEqThm declName
return o.map (#[·])

View file

@ -23,6 +23,18 @@ structure EqnInfo extends EqnInfoCore where
fixedPrefixSize : Nat
deriving Inhabited
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension
def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do
preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName
unless preDefs.all fun p => p.kind.isTheorem do
unless (← preDefs.allM fun p => isProp p.type) do
let declNames := preDefs.map (·.declName)
modifyEnv fun env =>
preDefs.foldl (init := env) fun env preDef =>
eqnInfoExt.insert env preDef.declName { preDef with
declNames, declNameNonRec, fixedPrefixSize }
private def deltaLHSUntilFix (declName declNameNonRec : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target ← mvarId.getType'
let some (_, lhs, rhs) := target.eq? | throwTacticEx `deltaLHSUntilFix mvarId "equality expected"
@ -53,62 +65,50 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
mvarId.assign (← mkEqTrans h mvarNew)
return mvarNew.mvarId!
private partial def mkProof (declName : Name) (declNameNonRec : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.partialFixpoint] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← main.mvarId!.intros
let mvarId ← deltaLHSUntilFix declName declNameNonRec mvarId
let mvarId ← rwFixEq mvarId
if ← withAtLeastTransparency .all (tryURefl mvarId) then
instantiateMVars main
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
/-- Generate the "unfold" lemma for `declName`. -/
def mkUnfoldEq (declName : Name) (info : EqnInfo) : MetaM Name := withLCtx {} {} do
withOptions (tactic.hygienic.set · false) do
let baseName := declName
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let target ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
let goal ← mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes info.declNames goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.partialFixpoint] "{eqnTypes[i]}"
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkProof declName info.declNameNonRec type
let (type, value) ← removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension
def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do
preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName
unless preDefs.all fun p => p.kind.isTheorem do
unless (← preDefs.allM fun p => isProp p.type) do
let declNames := preDefs.map (·.declName)
modifyEnv fun env =>
preDefs.foldl (init := env) fun env preDef =>
eqnInfoExt.insert env preDef.declName { preDef with
declNames, declNameNonRec, fixedPrefixSize }
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some info := eqnInfoExt.find? (← getEnv) declName then
mkEqns declName info
else
return none
let baseName := declName
lambdaTelescope info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let type ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
let goal ← withNewMCtxDepth do
try
let goal ← mkFreshExprSyntheticOpaqueMVar type
let mvarId := goal.mvarId!
trace[Elab.definition.partialFixpoint] "mkUnfoldEq start:{mvarId}"
let mvarId ← deltaLHSUntilFix declName info.declNameNonRec mvarId
trace[Elab.definition.partialFixpoint] "mkUnfoldEq after deltaLHS:{mvarId}"
let mvarId ← rwFixEq mvarId
trace[Elab.definition.partialFixpoint] "mkUnfoldEq after rwFixEq:{mvarId}"
withAtLeastTransparency .all <|
withOptions (smartUnfolding.set · false) <|
mvarId.refl
trace[Elab.definition.partialFixpoint] "mkUnfoldEq rfl succeeded"
instantiateMVars goal
catch e =>
throwError "failed to generate unfold theorem for '{declName}':\n{e.toMessageData}"
let type ← mkForallFVars xs type
let value ← mkLambdaFVars xs goal
let name := Name.str baseName unfoldThmSuffix
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return name
def getUnfoldFor? (declName : Name) : MetaM (Option Name) := do
let name := Name.str declName unfoldThmSuffix
let env ← getEnv
Eqns.getUnfoldFor? declName fun _ => eqnInfoExt.find? env declName |>.map (·.toEqnInfoCore)
if env.contains name then return name
let some info := eqnInfoExt.find? env declName | return none
return some (← mkUnfoldEq declName info)
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some _ := eqnInfoExt.find? (← getEnv) declName then
mkEqns declName
else
return none
builtin_initialize
registerGetEqnsFn getEqnsFor?

View file

@ -0,0 +1,101 @@
def find42 : Nat → Bool
| 42 => true
| n => find42 (n + 1)
partial_fixpoint
/--
info: find42.eq_def (x✝ : Nat) :
find42 x✝ =
match x✝ with
| 42 => true
| n => find42 (n + 1)
-/
#guard_msgs in
#check find42.eq_def
/--
info: equations:
theorem find42.eq_1 : find42 42 = true
theorem find42.eq_2 : ∀ (x : Nat), (x = 42 → False) → find42 x = find42 (x + 1)
-/
#guard_msgs in
#print equations find42
mutual
def find99 : Nat → Bool
| 99 => true
| n => find23 (n + 1)
partial_fixpoint
def find23 : Nat → Bool
| 23 => true
| n => find99 (n + 1)
partial_fixpoint
end
/--
info: find99.eq_def (x✝ : Nat) :
find99 x✝ =
match x✝ with
| 99 => true
| n => find23 (n + 1)
-/
#guard_msgs in
#check find99.eq_def
/--
info: find23.eq_def (x✝ : Nat) :
find23 x✝ =
match x✝ with
| 23 => true
| n => find99 (n + 1)
-/
#guard_msgs in
#check find23.eq_def
/--
info: equations:
theorem find99.eq_1 : find99 99 = true
theorem find99.eq_2 : ∀ (x : Nat), (x = 99 → False) → find99 x = find23 (x + 1)
-/
#guard_msgs in
#print equations find99
/--
info: equations:
theorem find23.eq_1 : find23 23 = true
theorem find23.eq_2 : ∀ (x : Nat), (x = 23 → False) → find23 x = find99 (x + 1)
-/
#guard_msgs in
#print equations find23
mutual
def g (i j : Nat) : Nat :=
if i < 5 then 0 else
match j with
| Nat.zero => 1
| Nat.succ j => h i j
partial_fixpoint
def h (i j : Nat) : Nat :=
match j with
| 0 => g i 0
| Nat.succ j => g i j
partial_fixpoint
end
/--
info: equations:
theorem g.eq_1 : ∀ (i : Nat), g i Nat.zero = if i < 5 then 0 else 1
theorem g.eq_2 : ∀ (i j_2 : Nat), g i j_2.succ = if i < 5 then 0 else h i j_2
-/
#guard_msgs in
#print equations g
/--
info: equations:
theorem h.eq_1 : ∀ (i : Nat), h i 0 = g i 0
theorem h.eq_2 : ∀ (i j_2 : Nat), h i j_2.succ = g i j_2
-/
#guard_msgs in
#print equations h