diff --git a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean index 62463832b5..76b8a36079 100644 --- a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean @@ -5,12 +5,40 @@ Authors: Leonardo de Moura -/ import Lean.Meta.Eqns import Lean.Meta.Tactic.Split +import Lean.Meta.Tactic.Apply import Lean.Elab.PreDefinition.Basic import Lean.Elab.PreDefinition.Structural.Basic -namespace Lean.Elab.Structural +namespace Lean.Elab open Meta +/-- Try to close goal using `rfl` with smart unfolding turned off. -/ +def tryURefl (mvarId : MVarId) : MetaM Bool := + withOptions (smartUnfolding.set . false) do + try applyRefl mvarId; return true catch _ => return false + +/-- Delta reduce the equation left-hand-side -/ +def deltaLHS (mvarId : MVarId) : MetaM MVarId := withMVarContext mvarId do + let target ← getMVarType' mvarId + let some (_, lhs, rhs) ← target.eq? | throwTacticEx `deltaLHS mvarId "equality expected" + let some lhs ← delta? lhs | throwTacticEx `deltaLHS mvarId "failed to delta reduce lhs" + replaceTargetDefEq mvarId (← mkEq lhs rhs) + +/-- Apply `whnfR` to lhs, return `none` if `lhs` was not modified -/ +def whnfReducibleLHS? (mvarId : MVarId) : MetaM (Option MVarId) := withMVarContext mvarId do + let target ← getMVarType' mvarId + let some (_, lhs, rhs) ← target.eq? | throwTacticEx `whnfReducibleLHS mvarId "equality expected" + let lhs' ← whnfR lhs + if lhs' != lhs then + return some (← replaceTargetDefEq mvarId (← mkEq lhs' rhs)) + else + return none + +def tryContradiction (mvarId : MVarId) : MetaM Bool := do + try contradiction mvarId { genDiseq := true }; return true catch _ => return false + +namespace Structural + structure EqnInfo where declName : Name levelParams : List Name @@ -25,13 +53,13 @@ private partial def expand : Expr → Expr | e => e private def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do - let target ← instantiateMVars (← getMVarType mvarId) + let target ← getMVarType' mvarId let some (_, lhs, rhs) ← target.eq? | return none unless rhs.isLet || rhs.isMData do return none return some (← replaceTargetDefEq mvarId (← mkEq lhs (expand rhs))) private def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do - let target ← getMVarType mvarId + let target ← getMVarType' mvarId let some (_, lhs, rhs) ← target.eq? | return none unless rhs.isLambda do return none commitWhenSome? do @@ -53,7 +81,7 @@ private def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do if `recArgPos == 1` -/ private def matchRecArg (mvarId : MVarId) (recArgPos : Nat) : MetaM Bool := do - let target ← instantiateMVars (← getMVarType mvarId) + let target ← getMVarType' mvarId let some (_, lhs, rhs) ← target.eq? | return false let lhsArgs := lhs.getAppArgs if h : recArgPos < lhsArgs.size then @@ -74,7 +102,7 @@ private def matchRecArg (mvarId : MVarId) (recArgPos : Nat) : MetaM Bool := do return true -- conservative answer private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do - let target ← instantiateMVars (← getMVarType mvarId) + let target ← getMVarType' mvarId let fvarIds := collectFVars {} target |>.fvarSet.toArray let (_, mvarId) ← revert mvarId fvarIds let type ← instantiateMVars (← getMVarType mvarId) @@ -98,30 +126,50 @@ private partial def mkEqnTypes (mvarId : MVarId) : ReaderT EqnInfo (StateRefT (A private def mkBaseNameFor (env : Environment) (declName : Name) : Name := Lean.mkBaseNameFor env declName `eq_1 `_eqns -private def mkProof (type : Expr) : MetaM Expr := do - -- TODO - mkSorry type false +private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := + withNewMCtxDepth do + let main ← mkFreshExprSyntheticOpaqueMVar type + let (_, mvarId) ← intros main.mvarId! + unless (← tryURefl mvarId) do -- catch easy cases + go (← deltaLHS mvarId) + instantiateMVars main +where + go (mvarId : MVarId) : MetaM Unit := do + trace[Elab.definition.structural.eqns] "step\n{MessageData.ofGoal mvarId}" + if (← tryURefl mvarId) then + return () + else if (← tryContradiction mvarId) then + return () + else if let some mvarId ← whnfReducibleLHS? mvarId then + go mvarId + else if let some mvarId ← simpMatch? mvarId then + go mvarId + else if let some mvarIds ← casesOnStuckLHS? mvarId then + mvarIds.forM go + else + throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" def mkEqns (info : EqnInfo) : MetaM (Array Name) := do withOptions (tactic.hygienic.set . false) do - lambdaTelescope info.value fun xs body => do + let eqnTypes ← withNewMCtxDepth <| lambdaTelescope info.value fun xs body => do let us := info.levelParams.map mkLevelParam let target ← mkEq (mkAppN (Lean.mkConst info.declName us) xs) body let goal ← mkFreshExprSyntheticOpaqueMVar target let (_, eqnTypes) ← mkEqnTypes goal.mvarId! |>.run info |>.run #[] - let baseName := mkBaseNameFor (← getEnv) info.declName - let mut thmNames := #[] - for i in [: eqnTypes.size] do - let type := eqnTypes[i] - trace[Elab.definition.structural.eqns] "{eqnTypes[i]}" - let name := baseName ++ (`eq).appendIndexAfter (i+1) - thmNames := thmNames.push name - let value ← mkProof type - addDecl <| Declaration.thmDecl { - name, type, value - levelParams := info.levelParams - } - return thmNames + return eqnTypes + let baseName := mkBaseNameFor (← getEnv) info.declName + let mut thmNames := #[] + for i in [: eqnTypes.size] do + let type := eqnTypes[i] + trace[Elab.definition.structural.eqns] "{eqnTypes[i]}" + let name := baseName ++ (`eq).appendIndexAfter (i+1) + thmNames := thmNames.push name + let value ← mkProof info.declName type + addDecl <| Declaration.thmDecl { + name, type, value + levelParams := info.levelParams + } + return thmNames builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension `structEqInfo @@ -151,4 +199,5 @@ builtin_initialize registerGetEqnsFn getEqnsFor? registerTraceClass `Elab.definition.structural.eqns -end Lean.Elab.Structural +end Structural +end Lean.Elab diff --git a/tests/lean/run/structuralEqns.lean b/tests/lean/run/structuralEqns.lean index 126711802f..b12ef176ed 100644 --- a/tests/lean/run/structuralEqns.lean +++ b/tests/lean/run/structuralEqns.lean @@ -1,7 +1,5 @@ import Lean -set_option trace.Elab.definition.structural.eqns true - open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do