diff --git a/src/Lean/Elab/PreDefinition/Eqns.lean b/src/Lean/Elab/PreDefinition/Eqns.lean index f086bae871..3a3bbee865 100644 --- a/src/Lean/Elab/PreDefinition/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Eqns.lean @@ -166,9 +166,28 @@ structure UnfoldEqnExtState where builtin_initialize unfoldEqnExt : EnvExtension UnfoldEqnExtState ← registerEnvExtension (pure {}) -def mkUnfoldProof (declName : Name) (mvarId : MVarId) (eqs : Array Name) : MetaM Unit := do - -- TODO - throwError "failed to generate unfold theorem for '{declName}'\n{MessageData.ofGoal mvarId}" +private def tryEqns (mvarId : MVarId) (eqs : Array Name) : MetaM Bool := + eqs.anyM fun eq => commitWhen do + try + let subgoals ← apply mvarId (← mkConstWithFreshMVarLevels eq) + subgoals.allM assumptionCore + catch _ => + return false + +partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) (eqs : Array Name) : MetaM Unit := do + go mvarId +where + go (mvarId : MVarId) : MetaM Unit := do + if (← tryEqns mvarId eqs) then + return () + else if let some mvarId ← funext? mvarId then + go mvarId + else if let some mvarId ← simpMatch? mvarId then + go mvarId + else if let some mvarIds ← splitTarget? mvarId then + mvarIds.forM go + else + throwError "failed to generate unfold theorem for '{declName}'\n{MessageData.ofGoal mvarId}" def mkUnfoldEq (declName : Name) (info : EqnInfoCore) : MetaM Name := do let env ← getEnv diff --git a/tests/lean/run/structuralEqns.lean b/tests/lean/run/structuralEqns.lean index eef32516ec..edfdae6db5 100644 --- a/tests/lean/run/structuralEqns.lean +++ b/tests/lean/run/structuralEqns.lean @@ -3,11 +3,12 @@ import Lean open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do - IO.println (← getEqnsFor? declName) + IO.println (← getUnfoldEqnFor? declName) #eval tst ``List.map #check @List.map._eq_1 #check @List.map._eq_2 +#check @List.map._unfold def foo (xs ys zs : List Nat) : List Nat := match (xs, ys) with @@ -22,6 +23,7 @@ def foo (xs ys zs : List Nat) : List Nat := #check foo._eq_1 #check foo._eq_2 +#check foo._unfold #eval tst ``foo @@ -38,6 +40,7 @@ def g : List Nat → List Nat → Nat #check g._eq_3 #check g._eq_4 #check g._eq_5 +#check g._unfold def h (xs : List Nat) (y : Nat) : Nat := match xs with @@ -51,6 +54,7 @@ def h (xs : List Nat) (y : Nat) : Nat := #check h._eq_1 #check h._eq_2 #check h._eq_3 +#check h._unfold def r (i j : Nat) : Nat := i + @@ -65,6 +69,7 @@ def r (i j : Nat) : Nat := #check r._eq_1 #check r._eq_2 #check r._eq_3 +#check r._unfold def bla (f g : α → α → α) (a : α) (i : α) (j : Nat) : α := f i <| @@ -79,3 +84,4 @@ def bla (f g : α → α → α) (a : α) (i : α) (j : Nat) : α := #check @bla._eq_1 #check @bla._eq_2 #check @bla._eq_3 +#check @bla._unfold diff --git a/tests/lean/run/structuralEqns2.lean b/tests/lean/run/structuralEqns2.lean index 9cbc1e6638..bb96df04cf 100644 --- a/tests/lean/run/structuralEqns2.lean +++ b/tests/lean/run/structuralEqns2.lean @@ -3,7 +3,7 @@ import Lean open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do - IO.println (← getEqnsFor? declName) + IO.println (← getUnfoldEqnFor? declName) def g (i j : Nat) : Nat := if i < 5 then 0 else @@ -15,6 +15,7 @@ def g (i j : Nat) : Nat := #check g._eq_1 #check g._eq_2 #check g._eq_3 +#check g._unfold def h (i j : Nat) : Nat := let z := @@ -26,3 +27,4 @@ def h (i j : Nat) : Nat := #eval tst ``h #check h._eq_1 #check h._eq_2 +#check h._unfold diff --git a/tests/lean/run/structuralEqns3.lean b/tests/lean/run/structuralEqns3.lean index 2f97b830ee..ac72ee0522 100644 --- a/tests/lean/run/structuralEqns3.lean +++ b/tests/lean/run/structuralEqns3.lean @@ -3,7 +3,7 @@ import Lean open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do - IO.println (← getEqnsFor? declName) + IO.println (← getUnfoldEqnFor? declName) inductive Wk: Nat -> Nat -> Type 0 where | id: Wk n n @@ -17,3 +17,4 @@ def wk_comp : Wk n m → Wk m l → Wk n l #check @wk_comp._eq_1 #check @wk_comp._eq_2 +#check @wk_comp._unfold diff --git a/tests/lean/run/wfEqns1.lean b/tests/lean/run/wfEqns1.lean index 4011d29193..9823a6e627 100644 --- a/tests/lean/run/wfEqns1.lean +++ b/tests/lean/run/wfEqns1.lean @@ -3,7 +3,7 @@ import Lean open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do - IO.println (← getEqnsFor? declName) + IO.println (← getUnfoldEqnFor? declName) mutual def isEven : Nat → Bool @@ -25,3 +25,4 @@ decreasing_by #eval tst ``isEven #check @isEven._eq_1 #check @isEven._eq_2 +#check @isEven._unfold diff --git a/tests/lean/run/wfEqns2.lean b/tests/lean/run/wfEqns2.lean index 8d4c4d3ba3..51ebaef02f 100644 --- a/tests/lean/run/wfEqns2.lean +++ b/tests/lean/run/wfEqns2.lean @@ -3,7 +3,7 @@ import Lean open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do - IO.println (← getEqnsFor? declName) + IO.println (← getUnfoldEqnFor? declName) mutual def g (i j : Nat) : Nat := @@ -34,6 +34,8 @@ decreasing_by #check g._eq_1 #check g._eq_2 #check g._eq_3 +#check g._unfold #eval tst ``h #check h._eq_1 #check h._eq_2 +#check h._unfold diff --git a/tests/lean/run/wfEqns3.lean b/tests/lean/run/wfEqns3.lean index 38f6ddffeb..11a4dfc102 100644 --- a/tests/lean/run/wfEqns3.lean +++ b/tests/lean/run/wfEqns3.lean @@ -3,7 +3,7 @@ import Lean open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do - IO.println (← getEqnsFor? declName) + IO.println (← getUnfoldEqnFor? declName) def f (x : Nat) : Nat := if h : x = 0 then @@ -17,3 +17,4 @@ decreasing_by #eval tst ``f #check f._eq_1 +#check f._unfold diff --git a/tests/lean/run/wfEqns4.lean b/tests/lean/run/wfEqns4.lean index 8a5e8c5faf..8134b11939 100644 --- a/tests/lean/run/wfEqns4.lean +++ b/tests/lean/run/wfEqns4.lean @@ -3,7 +3,7 @@ import Lean open Lean open Lean.Meta def tst (declName : Name) : MetaM Unit := do - IO.println (← getEqnsFor? declName) + IO.println (← getUnfoldEqnFor? declName) mutual def f : Nat → α → α → α @@ -38,7 +38,10 @@ decreasing_by #eval tst ``f #check @f._eq_1 #check @f._eq_2 +#check @f._unfold + #eval tst ``h #check @h._eq_1 #check @h._eq_2 +#check @h._unfold