fix: more robust equational theorems generation for partial_fixpoint (#6790)
This PR fixes an issue with the generation of equational theorems from `partial_fixpoint` when case-splitting is necessary. Fixes #6786.
This commit is contained in:
parent
3aea0fd810
commit
3418d6db8e
4 changed files with 264 additions and 113 deletions
|
|
@ -308,6 +308,115 @@ def whnfReducibleLHS? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withCo
|
|||
def tryContradiction (mvarId : MVarId) : MetaM Bool := do
|
||||
mvarId.contradictionCore { genDiseq := true }
|
||||
|
||||
/--
|
||||
Returns the type of the unfold theorem, as the starting point for calculating the equational
|
||||
types.
|
||||
-/
|
||||
private def unfoldThmType (declName : Name) : MetaM Expr := do
|
||||
if let some unfoldThm ← getUnfoldEqnFor? declName (nonRec := false) then
|
||||
let info ← getConstInfo unfoldThm
|
||||
pure info.type
|
||||
else
|
||||
let info ← getConstInfoDefn declName
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
|
||||
let type ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
|
||||
mkForallFVars xs type
|
||||
|
||||
private def unfoldLHS (declName : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
||||
if let some unfoldThm ← getUnfoldEqnFor? declName (nonRec := false) then
|
||||
-- Recursive definition: Use unfolding lemma
|
||||
let mut mvarId := mvarId
|
||||
let target ← mvarId.getType'
|
||||
let some (_, lhs, rhs) := target.eq? | throwError "unfoldLHS: Unexpected target {target}"
|
||||
unless lhs.isAppOf declName do throwError "unfoldLHS: Unexpected LHS {lhs}"
|
||||
let h := mkAppN (.const unfoldThm lhs.getAppFn.constLevels!) lhs.getAppArgs
|
||||
let some (_, _, lhsNew) := (← inferType h).eq? | unreachable!
|
||||
let targetNew ← mkEq lhsNew rhs
|
||||
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
|
||||
mvarId.assign (← mkEqTrans h mvarNew)
|
||||
return mvarNew.mvarId!
|
||||
else
|
||||
-- Else use delta reduction
|
||||
deltaLHS mvarId
|
||||
|
||||
private partial def mkEqnProof (declName : Name) (type : Expr) : MetaM Expr := do
|
||||
trace[Elab.definition.eqns] "proving: {type}"
|
||||
withNewMCtxDepth do
|
||||
let main ← mkFreshExprSyntheticOpaqueMVar type
|
||||
let (_, mvarId) ← main.mvarId!.intros
|
||||
-- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make
|
||||
-- the lemma ineligible for dsimp
|
||||
unless ← withAtLeastTransparency .all (tryURefl mvarId) do
|
||||
go (← unfoldLHS declName mvarId)
|
||||
instantiateMVars main
|
||||
where
|
||||
/--
|
||||
The core loop of proving an equation. Assumes that the function call on the left-hand side has
|
||||
already been unfolded, using whatever method applies to the current function definition strategy.
|
||||
|
||||
Currently used for non-recursive functions and partial fixpoints; maybe later well-founded
|
||||
recursion and structural recursion can and should use this too.
|
||||
-/
|
||||
go (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}"
|
||||
if ← withAtLeastTransparency .all (tryURefl mvarId) then
|
||||
return ()
|
||||
else if (← tryContradiction mvarId) then
|
||||
return ()
|
||||
else if let some mvarId ← simpMatch? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← simpIf? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← whnfReducibleLHS? mvarId then
|
||||
go mvarId
|
||||
else
|
||||
let ctx ← Simp.mkContext (config := { dsimp := false })
|
||||
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
|
||||
| TacticResultCNM.closed => return ()
|
||||
| TacticResultCNM.modified mvarId => go mvarId
|
||||
| TacticResultCNM.noChange =>
|
||||
if let some mvarIds ← casesOnStuckLHS? mvarId then
|
||||
mvarIds.forM go
|
||||
else if let some mvarIds ← splitTarget? mvarId then
|
||||
mvarIds.forM go
|
||||
else
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
|
||||
|
||||
/--
|
||||
Generate equations for `declName`.
|
||||
|
||||
This unfolds the function application on the LHS (using an unfold theorem, if present, or else by
|
||||
delta-reduction), calculates the types for the equational theorems using `mkEqnTypes`, and then
|
||||
proves them using `mkEqnProof`.
|
||||
|
||||
This is currently used for non-recursive functions and for functions defined by partial_fixpoint.
|
||||
-/
|
||||
def mkEqns (declName : Name) : MetaM (Array Name) := do
|
||||
let info ← getConstInfoDefn declName
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
withOptions (tactic.hygienic.set · false) do
|
||||
let target ← unfoldThmType declName
|
||||
let eqnTypes ← withNewMCtxDepth <|
|
||||
forallTelescope (cleanupAnnotations := true) target fun xs target => do
|
||||
let goal ← mkFreshExprSyntheticOpaqueMVar target
|
||||
withReducible do
|
||||
mkEqnTypes #[] goal.mvarId!
|
||||
let mut thmNames := #[]
|
||||
for h : i in [: eqnTypes.size] do
|
||||
let type := eqnTypes[i]
|
||||
trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}"
|
||||
let name := (Name.str declName eqnThmSuffixBase).appendIndexAfter (i+1)
|
||||
thmNames := thmNames.push name
|
||||
let value ← mkEqnProof declName type
|
||||
let (type, value) ← removeUnusedEqnHypotheses type value
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return thmNames
|
||||
|
||||
/--
|
||||
Auxiliary method for `mkUnfoldEq`. The structure is based on `mkEqnTypes`.
|
||||
`mvarId` is the goal to be proved. It is a goal of the form
|
||||
|
|
|
|||
|
|
@ -33,71 +33,12 @@ private def mkSimpleEqThm (declName : Name) (suffix := Name.mkSimple unfoldThmSu
|
|||
else
|
||||
return none
|
||||
|
||||
private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
|
||||
trace[Elab.definition.eqns] "proving: {type}"
|
||||
withNewMCtxDepth do
|
||||
let main ← mkFreshExprSyntheticOpaqueMVar type
|
||||
let (_, mvarId) ← main.mvarId!.intros
|
||||
let rec go (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}"
|
||||
if ← withAtLeastTransparency .all (tryURefl mvarId) then
|
||||
return ()
|
||||
else if (← tryContradiction mvarId) then
|
||||
return ()
|
||||
else if let some mvarId ← simpMatch? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← simpIf? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← whnfReducibleLHS? mvarId then
|
||||
go mvarId
|
||||
else
|
||||
let ctx ← Simp.mkContext (config := { dsimp := false })
|
||||
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
|
||||
| TacticResultCNM.closed => return ()
|
||||
| TacticResultCNM.modified mvarId => go mvarId
|
||||
| TacticResultCNM.noChange =>
|
||||
if let some mvarIds ← casesOnStuckLHS? mvarId then
|
||||
mvarIds.forM go
|
||||
else if let some mvarIds ← splitTarget? mvarId then
|
||||
mvarIds.forM go
|
||||
else
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
|
||||
-- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make
|
||||
-- the lemma ineligible for dsimp
|
||||
unless ← withAtLeastTransparency .all (tryURefl mvarId) do
|
||||
go (← deltaLHS mvarId)
|
||||
instantiateMVars main
|
||||
|
||||
def mkEqns (declName : Name) (info : DefinitionVal) : MetaM (Array Name) :=
|
||||
withOptions (tactic.hygienic.set · false) do
|
||||
let baseName := declName
|
||||
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
let target ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
|
||||
let goal ← mkFreshExprSyntheticOpaqueMVar target
|
||||
withReducible do
|
||||
mkEqnTypes #[] goal.mvarId!
|
||||
let mut thmNames := #[]
|
||||
for h : i in [: eqnTypes.size] do
|
||||
let type := eqnTypes[i]
|
||||
trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}"
|
||||
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
|
||||
thmNames := thmNames.push name
|
||||
let value ← mkProof declName type
|
||||
let (type, value) ← removeUnusedEqnHypotheses type value
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return thmNames
|
||||
|
||||
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
|
||||
if (← isRecursiveDefinition declName) then
|
||||
return none
|
||||
if let some (.defnInfo info) := (← getEnv).find? declName then
|
||||
if (← getEnv).contains declName then
|
||||
if backward.eqns.nonrecursive.get (← getOptions) then
|
||||
mkEqns declName info
|
||||
mkEqns declName
|
||||
else
|
||||
let o ← mkSimpleEqThm declName
|
||||
return o.map (#[·])
|
||||
|
|
|
|||
|
|
@ -23,6 +23,18 @@ structure EqnInfo extends EqnInfoCore where
|
|||
fixedPrefixSize : Nat
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension
|
||||
|
||||
def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do
|
||||
preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName
|
||||
unless preDefs.all fun p => p.kind.isTheorem do
|
||||
unless (← preDefs.allM fun p => isProp p.type) do
|
||||
let declNames := preDefs.map (·.declName)
|
||||
modifyEnv fun env =>
|
||||
preDefs.foldl (init := env) fun env preDef =>
|
||||
eqnInfoExt.insert env preDef.declName { preDef with
|
||||
declNames, declNameNonRec, fixedPrefixSize }
|
||||
|
||||
private def deltaLHSUntilFix (declName declNameNonRec : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
||||
let target ← mvarId.getType'
|
||||
let some (_, lhs, rhs) := target.eq? | throwTacticEx `deltaLHSUntilFix mvarId "equality expected"
|
||||
|
|
@ -53,62 +65,50 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
|||
mvarId.assign (← mkEqTrans h mvarNew)
|
||||
return mvarNew.mvarId!
|
||||
|
||||
private partial def mkProof (declName : Name) (declNameNonRec : Name) (type : Expr) : MetaM Expr := do
|
||||
trace[Elab.definition.partialFixpoint] "proving: {type}"
|
||||
withNewMCtxDepth do
|
||||
let main ← mkFreshExprSyntheticOpaqueMVar type
|
||||
let (_, mvarId) ← main.mvarId!.intros
|
||||
let mvarId ← deltaLHSUntilFix declName declNameNonRec mvarId
|
||||
let mvarId ← rwFixEq mvarId
|
||||
if ← withAtLeastTransparency .all (tryURefl mvarId) then
|
||||
instantiateMVars main
|
||||
else
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
|
||||
def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
|
||||
/-- Generate the "unfold" lemma for `declName`. -/
|
||||
def mkUnfoldEq (declName : Name) (info : EqnInfo) : MetaM Name := withLCtx {} {} do
|
||||
withOptions (tactic.hygienic.set · false) do
|
||||
let baseName := declName
|
||||
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
let target ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
|
||||
let goal ← mkFreshExprSyntheticOpaqueMVar target
|
||||
withReducible do
|
||||
mkEqnTypes info.declNames goal.mvarId!
|
||||
let mut thmNames := #[]
|
||||
for h : i in [: eqnTypes.size] do
|
||||
let type := eqnTypes[i]
|
||||
trace[Elab.definition.partialFixpoint] "{eqnTypes[i]}"
|
||||
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
|
||||
thmNames := thmNames.push name
|
||||
let value ← mkProof declName info.declNameNonRec type
|
||||
let (type, value) ← removeUnusedEqnHypotheses type value
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return thmNames
|
||||
|
||||
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension
|
||||
|
||||
def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do
|
||||
preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName
|
||||
unless preDefs.all fun p => p.kind.isTheorem do
|
||||
unless (← preDefs.allM fun p => isProp p.type) do
|
||||
let declNames := preDefs.map (·.declName)
|
||||
modifyEnv fun env =>
|
||||
preDefs.foldl (init := env) fun env preDef =>
|
||||
eqnInfoExt.insert env preDef.declName { preDef with
|
||||
declNames, declNameNonRec, fixedPrefixSize }
|
||||
|
||||
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
|
||||
if let some info := eqnInfoExt.find? (← getEnv) declName then
|
||||
mkEqns declName info
|
||||
else
|
||||
return none
|
||||
let baseName := declName
|
||||
lambdaTelescope info.value fun xs body => do
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
let type ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
|
||||
let goal ← withNewMCtxDepth do
|
||||
try
|
||||
let goal ← mkFreshExprSyntheticOpaqueMVar type
|
||||
let mvarId := goal.mvarId!
|
||||
trace[Elab.definition.partialFixpoint] "mkUnfoldEq start:{mvarId}"
|
||||
let mvarId ← deltaLHSUntilFix declName info.declNameNonRec mvarId
|
||||
trace[Elab.definition.partialFixpoint] "mkUnfoldEq after deltaLHS:{mvarId}"
|
||||
let mvarId ← rwFixEq mvarId
|
||||
trace[Elab.definition.partialFixpoint] "mkUnfoldEq after rwFixEq:{mvarId}"
|
||||
withAtLeastTransparency .all <|
|
||||
withOptions (smartUnfolding.set · false) <|
|
||||
mvarId.refl
|
||||
trace[Elab.definition.partialFixpoint] "mkUnfoldEq rfl succeeded"
|
||||
instantiateMVars goal
|
||||
catch e =>
|
||||
throwError "failed to generate unfold theorem for '{declName}':\n{e.toMessageData}"
|
||||
let type ← mkForallFVars xs type
|
||||
let value ← mkLambdaFVars xs goal
|
||||
let name := Name.str baseName unfoldThmSuffix
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return name
|
||||
|
||||
def getUnfoldFor? (declName : Name) : MetaM (Option Name) := do
|
||||
let name := Name.str declName unfoldThmSuffix
|
||||
let env ← getEnv
|
||||
Eqns.getUnfoldFor? declName fun _ => eqnInfoExt.find? env declName |>.map (·.toEqnInfoCore)
|
||||
if env.contains name then return name
|
||||
let some info := eqnInfoExt.find? env declName | return none
|
||||
return some (← mkUnfoldEq declName info)
|
||||
|
||||
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
|
||||
if let some _ := eqnInfoExt.find? (← getEnv) declName then
|
||||
mkEqns declName
|
||||
else
|
||||
return none
|
||||
|
||||
builtin_initialize
|
||||
registerGetEqnsFn getEqnsFor?
|
||||
|
|
|
|||
101
tests/lean/run/issue6786.lean
Normal file
101
tests/lean/run/issue6786.lean
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
def find42 : Nat → Bool
|
||||
| 42 => true
|
||||
| n => find42 (n + 1)
|
||||
partial_fixpoint
|
||||
|
||||
/--
|
||||
info: find42.eq_def (x✝ : Nat) :
|
||||
find42 x✝ =
|
||||
match x✝ with
|
||||
| 42 => true
|
||||
| n => find42 (n + 1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check find42.eq_def
|
||||
|
||||
/--
|
||||
info: equations:
|
||||
theorem find42.eq_1 : find42 42 = true
|
||||
theorem find42.eq_2 : ∀ (x : Nat), (x = 42 → False) → find42 x = find42 (x + 1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print equations find42
|
||||
|
||||
mutual
|
||||
def find99 : Nat → Bool
|
||||
| 99 => true
|
||||
| n => find23 (n + 1)
|
||||
partial_fixpoint
|
||||
def find23 : Nat → Bool
|
||||
| 23 => true
|
||||
| n => find99 (n + 1)
|
||||
partial_fixpoint
|
||||
end
|
||||
|
||||
/--
|
||||
info: find99.eq_def (x✝ : Nat) :
|
||||
find99 x✝ =
|
||||
match x✝ with
|
||||
| 99 => true
|
||||
| n => find23 (n + 1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check find99.eq_def
|
||||
|
||||
/--
|
||||
info: find23.eq_def (x✝ : Nat) :
|
||||
find23 x✝ =
|
||||
match x✝ with
|
||||
| 23 => true
|
||||
| n => find99 (n + 1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check find23.eq_def
|
||||
|
||||
/--
|
||||
info: equations:
|
||||
theorem find99.eq_1 : find99 99 = true
|
||||
theorem find99.eq_2 : ∀ (x : Nat), (x = 99 → False) → find99 x = find23 (x + 1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print equations find99
|
||||
|
||||
/--
|
||||
info: equations:
|
||||
theorem find23.eq_1 : find23 23 = true
|
||||
theorem find23.eq_2 : ∀ (x : Nat), (x = 23 → False) → find23 x = find99 (x + 1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print equations find23
|
||||
|
||||
|
||||
mutual
|
||||
def g (i j : Nat) : Nat :=
|
||||
if i < 5 then 0 else
|
||||
match j with
|
||||
| Nat.zero => 1
|
||||
| Nat.succ j => h i j
|
||||
partial_fixpoint
|
||||
|
||||
def h (i j : Nat) : Nat :=
|
||||
match j with
|
||||
| 0 => g i 0
|
||||
| Nat.succ j => g i j
|
||||
partial_fixpoint
|
||||
end
|
||||
|
||||
/--
|
||||
info: equations:
|
||||
theorem g.eq_1 : ∀ (i : Nat), g i Nat.zero = if i < 5 then 0 else 1
|
||||
theorem g.eq_2 : ∀ (i j_2 : Nat), g i j_2.succ = if i < 5 then 0 else h i j_2
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print equations g
|
||||
|
||||
/--
|
||||
info: equations:
|
||||
theorem h.eq_1 : ∀ (i : Nat), h i 0 = g i 0
|
||||
theorem h.eq_2 : ∀ (i j_2 : Nat), h i j_2.succ = g i j_2
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print equations h
|
||||
Loading…
Add table
Reference in a new issue