diff --git a/src/Lean/Elab/PreDefinition/WF/Eqns.lean b/src/Lean/Elab/PreDefinition/WF/Eqns.lean index 7fdd43540f..c7e7342100 100644 --- a/src/Lean/Elab/PreDefinition/WF/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/WF/Eqns.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ import Lean.Meta.Tactic.Rewrite +import Lean.Meta.Tactic.Split import Lean.Elab.PreDefinition.Basic import Lean.Elab.PreDefinition.Eqns @@ -35,7 +36,29 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := withMVarContext mvarId d assignExprMVar mvarId (← mkEqTrans h mvarNew) return mvarNew.mvarId! -private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do +private def hasWellFoundedFix (e : Expr) : Bool := + Option.isSome <| e.find? (·.isConstOf ``WellFounded.fix) + +def simpMatchWF? (mvarId : MVarId) (info : EqnInfo) : MetaM (Option MVarId) := withMVarContext mvarId do + let target ← instantiateMVars (← getMVarType mvarId) + let targetNew ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre }) + let mvarIdNew ← applySimpResultToTarget mvarId target targetNew + if mvarId != mvarIdNew then return some mvarIdNew else return none +where + pre (e : Expr) : SimpM Simp.Step := do + let some app ← matchMatcherApp? e | return Simp.Step.visit { expr := e } + if app.discrs.any hasWellFoundedFix then + -- TODO: try to fold `WellFounded.fix` occurrences in the discriminant + pure () + -- First try to reduce matcher + match (← reduceRecMatcher? e) with + | some e' => return Simp.Step.done { expr := e' } + | none => + match (← Simp.simpMatchCore? app e SplitIf.discharge?) with + | some r => return r + | none => return Simp.Step.visit { expr := e } + +private partial def mkProof (declName : Name) (info : EqnInfo) (type : Expr) : MetaM Expr := do trace[Elab.definition.wf.eqns] "proving: {type}" withNewMCtxDepth do let main ← mkFreshExprSyntheticOpaqueMVar type @@ -49,7 +72,7 @@ where return () else if (← tryContradiction mvarId) then return () - else if let some mvarId ← simpMatch? mvarId then + else if let some mvarId ← simpMatchWF? mvarId info then go mvarId else if let some mvarId ← simpIf? mvarId then go mvarId @@ -80,7 +103,7 @@ def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) := trace[Elab.definition.wf.eqns] "{eqnTypes[i]}" let name := baseName ++ (`_eq).appendIndexAfter (i+1) thmNames := thmNames.push name - let value ← mkProof declName type + let value ← mkProof declName info type let (type, value) ← removeUnusedEqnHypotheses type value addDecl <| Declaration.thmDecl { name, type, value diff --git a/src/Lean/Meta/Tactic/Split.lean b/src/Lean/Meta/Tactic/Split.lean index 1350300da5..4aaac1c7c5 100644 --- a/src/Lean/Meta/Tactic/Split.lean +++ b/src/Lean/Meta/Tactic/Split.lean @@ -9,7 +9,7 @@ import Lean.Meta.Tactic.Generalize namespace Lean.Meta namespace Split -private def getSimpMatchContext : MetaM Simp.Context := +def getSimpMatchContext : MetaM Simp.Context := return { simpTheorems := {} congrTheorems := (← getSimpCongrTheorems)