feat: generate proofs for structural (conditional) equality theorems

This commit is contained in:
Leonardo de Moura 2021-09-17 18:56:21 -07:00
parent 38090fa3c0
commit d413aa1dc5
2 changed files with 72 additions and 25 deletions

View file

@ -5,12 +5,40 @@ Authors: Leonardo de Moura
-/
import Lean.Meta.Eqns
import Lean.Meta.Tactic.Split
import Lean.Meta.Tactic.Apply
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Structural.Basic
namespace Lean.Elab.Structural
namespace Lean.Elab
open Meta
/-- Try to close goal using `rfl` with smart unfolding turned off. -/
def tryURefl (mvarId : MVarId) : MetaM Bool :=
withOptions (smartUnfolding.set . false) do
try applyRefl mvarId; return true catch _ => return false
/-- Delta reduce the equation left-hand-side -/
def deltaLHS (mvarId : MVarId) : MetaM MVarId := withMVarContext mvarId do
let target ← getMVarType' mvarId
let some (_, lhs, rhs) ← target.eq? | throwTacticEx `deltaLHS mvarId "equality expected"
let some lhs ← delta? lhs | throwTacticEx `deltaLHS mvarId "failed to delta reduce lhs"
replaceTargetDefEq mvarId (← mkEq lhs rhs)
/-- Apply `whnfR` to lhs, return `none` if `lhs` was not modified -/
def whnfReducibleLHS? (mvarId : MVarId) : MetaM (Option MVarId) := withMVarContext mvarId do
let target ← getMVarType' mvarId
let some (_, lhs, rhs) ← target.eq? | throwTacticEx `whnfReducibleLHS mvarId "equality expected"
let lhs' ← whnfR lhs
if lhs' != lhs then
return some (← replaceTargetDefEq mvarId (← mkEq lhs' rhs))
else
return none
def tryContradiction (mvarId : MVarId) : MetaM Bool := do
try contradiction mvarId { genDiseq := true }; return true catch _ => return false
namespace Structural
structure EqnInfo where
declName : Name
levelParams : List Name
@ -25,13 +53,13 @@ private partial def expand : Expr → Expr
| e => e
private def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do
let target ← instantiateMVars (← getMVarType mvarId)
let target ← getMVarType' mvarId
let some (_, lhs, rhs) ← target.eq? | return none
unless rhs.isLet || rhs.isMData do return none
return some (← replaceTargetDefEq mvarId (← mkEq lhs (expand rhs)))
private def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do
let target ← getMVarType mvarId
let target ← getMVarType' mvarId
let some (_, lhs, rhs) ← target.eq? | return none
unless rhs.isLambda do return none
commitWhenSome? do
@ -53,7 +81,7 @@ private def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do
if `recArgPos == 1`
-/
private def matchRecArg (mvarId : MVarId) (recArgPos : Nat) : MetaM Bool := do
let target ← instantiateMVars (← getMVarType mvarId)
let target ← getMVarType' mvarId
let some (_, lhs, rhs) ← target.eq? | return false
let lhsArgs := lhs.getAppArgs
if h : recArgPos < lhsArgs.size then
@ -74,7 +102,7 @@ private def matchRecArg (mvarId : MVarId) (recArgPos : Nat) : MetaM Bool := do
return true -- conservative answer
private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do
let target ← instantiateMVars (← getMVarType mvarId)
let target ← getMVarType' mvarId
let fvarIds := collectFVars {} target |>.fvarSet.toArray
let (_, mvarId) ← revert mvarId fvarIds
let type ← instantiateMVars (← getMVarType mvarId)
@ -98,30 +126,50 @@ private partial def mkEqnTypes (mvarId : MVarId) : ReaderT EqnInfo (StateRefT (A
private def mkBaseNameFor (env : Environment) (declName : Name) : Name :=
Lean.mkBaseNameFor env declName `eq_1 `_eqns
private def mkProof (type : Expr) : MetaM Expr := do
-- TODO
mkSorry type false
private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr :=
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) ← intros main.mvarId!
unless (← tryURefl mvarId) do -- catch easy cases
go (← deltaLHS mvarId)
instantiateMVars main
where
go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.structural.eqns] "step\n{MessageData.ofGoal mvarId}"
if (← tryURefl mvarId) then
return ()
else if (← tryContradiction mvarId) then
return ()
else if let some mvarId ← whnfReducibleLHS? mvarId then
go mvarId
else if let some mvarId ← simpMatch? mvarId then
go mvarId
else if let some mvarIds ← casesOnStuckLHS? mvarId then
mvarIds.forM go
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
def mkEqns (info : EqnInfo) : MetaM (Array Name) := do
withOptions (tactic.hygienic.set . false) do
lambdaTelescope info.value fun xs body => do
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let target ← mkEq (mkAppN (Lean.mkConst info.declName us) xs) body
let goal ← mkFreshExprSyntheticOpaqueMVar target
let (_, eqnTypes) ← mkEqnTypes goal.mvarId! |>.run info |>.run #[]
let baseName := mkBaseNameFor (← getEnv) info.declName
let mut thmNames := #[]
for i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.structural.eqns] "{eqnTypes[i]}"
let name := baseName ++ (`eq).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkProof type
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
return eqnTypes
let baseName := mkBaseNameFor (← getEnv) info.declName
let mut thmNames := #[]
for i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.structural.eqns] "{eqnTypes[i]}"
let name := baseName ++ (`eq).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkProof info.declName type
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension `structEqInfo
@ -151,4 +199,5 @@ builtin_initialize
registerGetEqnsFn getEqnsFor?
registerTraceClass `Elab.definition.structural.eqns
end Lean.Elab.Structural
end Structural
end Lean.Elab

View file

@ -1,7 +1,5 @@
import Lean
set_option trace.Elab.definition.structural.eqns true
open Lean
open Lean.Meta
def tst (declName : Name) : MetaM Unit := do