feat: generate proofs for structural (conditional) equality theorems
This commit is contained in:
parent
38090fa3c0
commit
d413aa1dc5
2 changed files with 72 additions and 25 deletions
|
|
@ -5,12 +5,40 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
import Lean.Meta.Eqns
|
||||
import Lean.Meta.Tactic.Split
|
||||
import Lean.Meta.Tactic.Apply
|
||||
import Lean.Elab.PreDefinition.Basic
|
||||
import Lean.Elab.PreDefinition.Structural.Basic
|
||||
|
||||
namespace Lean.Elab.Structural
|
||||
namespace Lean.Elab
|
||||
open Meta
|
||||
|
||||
/-- Try to close goal using `rfl` with smart unfolding turned off. -/
|
||||
def tryURefl (mvarId : MVarId) : MetaM Bool :=
|
||||
withOptions (smartUnfolding.set . false) do
|
||||
try applyRefl mvarId; return true catch _ => return false
|
||||
|
||||
/-- Delta reduce the equation left-hand-side -/
|
||||
def deltaLHS (mvarId : MVarId) : MetaM MVarId := withMVarContext mvarId do
|
||||
let target ← getMVarType' mvarId
|
||||
let some (_, lhs, rhs) ← target.eq? | throwTacticEx `deltaLHS mvarId "equality expected"
|
||||
let some lhs ← delta? lhs | throwTacticEx `deltaLHS mvarId "failed to delta reduce lhs"
|
||||
replaceTargetDefEq mvarId (← mkEq lhs rhs)
|
||||
|
||||
/-- Apply `whnfR` to lhs, return `none` if `lhs` was not modified -/
|
||||
def whnfReducibleLHS? (mvarId : MVarId) : MetaM (Option MVarId) := withMVarContext mvarId do
|
||||
let target ← getMVarType' mvarId
|
||||
let some (_, lhs, rhs) ← target.eq? | throwTacticEx `whnfReducibleLHS mvarId "equality expected"
|
||||
let lhs' ← whnfR lhs
|
||||
if lhs' != lhs then
|
||||
return some (← replaceTargetDefEq mvarId (← mkEq lhs' rhs))
|
||||
else
|
||||
return none
|
||||
|
||||
def tryContradiction (mvarId : MVarId) : MetaM Bool := do
|
||||
try contradiction mvarId { genDiseq := true }; return true catch _ => return false
|
||||
|
||||
namespace Structural
|
||||
|
||||
structure EqnInfo where
|
||||
declName : Name
|
||||
levelParams : List Name
|
||||
|
|
@ -25,13 +53,13 @@ private partial def expand : Expr → Expr
|
|||
| e => e
|
||||
|
||||
private def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do
|
||||
let target ← instantiateMVars (← getMVarType mvarId)
|
||||
let target ← getMVarType' mvarId
|
||||
let some (_, lhs, rhs) ← target.eq? | return none
|
||||
unless rhs.isLet || rhs.isMData do return none
|
||||
return some (← replaceTargetDefEq mvarId (← mkEq lhs (expand rhs)))
|
||||
|
||||
private def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do
|
||||
let target ← getMVarType mvarId
|
||||
let target ← getMVarType' mvarId
|
||||
let some (_, lhs, rhs) ← target.eq? | return none
|
||||
unless rhs.isLambda do return none
|
||||
commitWhenSome? do
|
||||
|
|
@ -53,7 +81,7 @@ private def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do
|
|||
if `recArgPos == 1`
|
||||
-/
|
||||
private def matchRecArg (mvarId : MVarId) (recArgPos : Nat) : MetaM Bool := do
|
||||
let target ← instantiateMVars (← getMVarType mvarId)
|
||||
let target ← getMVarType' mvarId
|
||||
let some (_, lhs, rhs) ← target.eq? | return false
|
||||
let lhsArgs := lhs.getAppArgs
|
||||
if h : recArgPos < lhsArgs.size then
|
||||
|
|
@ -74,7 +102,7 @@ private def matchRecArg (mvarId : MVarId) (recArgPos : Nat) : MetaM Bool := do
|
|||
return true -- conservative answer
|
||||
|
||||
private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do
|
||||
let target ← instantiateMVars (← getMVarType mvarId)
|
||||
let target ← getMVarType' mvarId
|
||||
let fvarIds := collectFVars {} target |>.fvarSet.toArray
|
||||
let (_, mvarId) ← revert mvarId fvarIds
|
||||
let type ← instantiateMVars (← getMVarType mvarId)
|
||||
|
|
@ -98,30 +126,50 @@ private partial def mkEqnTypes (mvarId : MVarId) : ReaderT EqnInfo (StateRefT (A
|
|||
private def mkBaseNameFor (env : Environment) (declName : Name) : Name :=
|
||||
Lean.mkBaseNameFor env declName `eq_1 `_eqns
|
||||
|
||||
private def mkProof (type : Expr) : MetaM Expr := do
|
||||
-- TODO
|
||||
mkSorry type false
|
||||
private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr :=
|
||||
withNewMCtxDepth do
|
||||
let main ← mkFreshExprSyntheticOpaqueMVar type
|
||||
let (_, mvarId) ← intros main.mvarId!
|
||||
unless (← tryURefl mvarId) do -- catch easy cases
|
||||
go (← deltaLHS mvarId)
|
||||
instantiateMVars main
|
||||
where
|
||||
go (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Elab.definition.structural.eqns] "step\n{MessageData.ofGoal mvarId}"
|
||||
if (← tryURefl mvarId) then
|
||||
return ()
|
||||
else if (← tryContradiction mvarId) then
|
||||
return ()
|
||||
else if let some mvarId ← whnfReducibleLHS? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarId ← simpMatch? mvarId then
|
||||
go mvarId
|
||||
else if let some mvarIds ← casesOnStuckLHS? mvarId then
|
||||
mvarIds.forM go
|
||||
else
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
|
||||
def mkEqns (info : EqnInfo) : MetaM (Array Name) := do
|
||||
withOptions (tactic.hygienic.set . false) do
|
||||
lambdaTelescope info.value fun xs body => do
|
||||
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope info.value fun xs body => do
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
let target ← mkEq (mkAppN (Lean.mkConst info.declName us) xs) body
|
||||
let goal ← mkFreshExprSyntheticOpaqueMVar target
|
||||
let (_, eqnTypes) ← mkEqnTypes goal.mvarId! |>.run info |>.run #[]
|
||||
let baseName := mkBaseNameFor (← getEnv) info.declName
|
||||
let mut thmNames := #[]
|
||||
for i in [: eqnTypes.size] do
|
||||
let type := eqnTypes[i]
|
||||
trace[Elab.definition.structural.eqns] "{eqnTypes[i]}"
|
||||
let name := baseName ++ (`eq).appendIndexAfter (i+1)
|
||||
thmNames := thmNames.push name
|
||||
let value ← mkProof type
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return thmNames
|
||||
return eqnTypes
|
||||
let baseName := mkBaseNameFor (← getEnv) info.declName
|
||||
let mut thmNames := #[]
|
||||
for i in [: eqnTypes.size] do
|
||||
let type := eqnTypes[i]
|
||||
trace[Elab.definition.structural.eqns] "{eqnTypes[i]}"
|
||||
let name := baseName ++ (`eq).appendIndexAfter (i+1)
|
||||
thmNames := thmNames.push name
|
||||
let value ← mkProof info.declName type
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return thmNames
|
||||
|
||||
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension `structEqInfo
|
||||
|
||||
|
|
@ -151,4 +199,5 @@ builtin_initialize
|
|||
registerGetEqnsFn getEqnsFor?
|
||||
registerTraceClass `Elab.definition.structural.eqns
|
||||
|
||||
end Lean.Elab.Structural
|
||||
end Structural
|
||||
end Lean.Elab
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import Lean
|
||||
|
||||
set_option trace.Elab.definition.structural.eqns true
|
||||
|
||||
open Lean
|
||||
open Lean.Meta
|
||||
def tst (declName : Name) : MetaM Unit := do
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue