diff --git a/src/Init/Tactics.lean b/src/Init/Tactics.lean index 16f301c5bc..f79235cadf 100644 --- a/src/Init/Tactics.lean +++ b/src/Init/Tactics.lean @@ -2423,6 +2423,16 @@ defining the thing you are rewriting. -/ syntax (name := method_specs_simp) "method_specs_simp" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")? (ppSpace prio)? : attr +/-- +Register a theorem as a rewrite rule for `cbv` evaluation of a given definition. + +You can instruct `cbv` to rewrite the lemma from right-to-left: +```lean +@[cbv_eval ←] theorem my_thm : rhs = lhs := ... +``` +-/ +syntax (name := cbv_eval) "cbv_eval" patternIgnore("← " <|> "<- ")? (ppSpace ident)? : attr + /-- The possible `norm_cast` kinds: `elim`, `move`, or `squash`. -/ syntax normCastLabel := &"elim" <|> &"move" <|> &"squash" diff --git a/src/Lean/Meta/Tactic/Cbv/CbvEvalExt.lean b/src/Lean/Meta/Tactic/Cbv/CbvEvalExt.lean index 63443bee30..3f37085667 100644 --- a/src/Lean/Meta/Tactic/Cbv/CbvEvalExt.lean +++ b/src/Lean/Meta/Tactic/Cbv/CbvEvalExt.lean @@ -9,6 +9,8 @@ public import Lean.Data.NameMap public import Lean.ScopedEnvExtension public import Lean.Elab.InfoTree public import Lean.Meta.Sym.Simp.Theorems +import Lean.Meta.Tactic.AuxLemma +import Lean.Meta.AppBuilder public section namespace Lean.Meta.Sym.Simp @@ -29,18 +31,24 @@ structure CbvEvalEntry where thm : Theorem deriving BEq, Inhabited -def mkCbvTheoremFromConst (declName : Name) : MetaM CbvEvalEntry := do +def mkCbvTheoremFromConst (declName : Name) (inv : Bool := false) : MetaM CbvEvalEntry := do let cinfo ← getConstVal declName let us := cinfo.levelParams.map mkLevelParam let val := mkConst declName us let type ← inferType val unless (← isProp type) do throwError "{val} is not a theorem and thus cannot be marked with `cbv_eval` attribute" - let fnName ← forallTelescope type fun _ body => do - let some (_, lhs, _) := body.eq? | throwError "The conclusion {type} of theorem {val} is not an equality" - let appFn := lhs.getAppFn - let some constName := appFn.constName? | throwError "The left-hand side of a theorem {val} is not an application of a constant" - return constName - let thm ← mkTheoremFromDecl declName + let (fnName, thmDeclName) ← forallTelescope type fun xs body => do + let some (_, lhs, rhs) := body.eq? | throwError "The conclusion {type} of theorem {val} is not an equality" + let matchSide := if inv then rhs else lhs + let some constName := matchSide.getAppFn.constName? + | throwError "The rewrite side of theorem {val} is not an application of a constant" + let mut thmDeclName := declName + if inv then + let invType ← mkForallFVars xs (← mkEq rhs lhs) + let invVal ← mkLambdaFVars xs (← mkEqSymm (mkAppN val xs)) + thmDeclName ← mkAuxLemma (kind? := `_cbv_eval) cinfo.levelParams invType invVal + return (constName, thmDeclName) + let thm ← mkTheoremFromDecl thmDeclName return ⟨fnName, thm⟩ structure CbvEvalState where @@ -74,8 +82,9 @@ builtin_initialize name := `cbv_eval descr := "Register a theorem as a rewrite rule for `cbv` evaluation of a given definition." applicationTime := AttributeApplicationTime.afterCompilation - add := fun lemmaName _ kind => do - let (entry, _) ← MetaM.run (mkCbvTheoremFromConst lemmaName) {} + add := fun lemmaName stx kind => do + let inv := !stx[1].isNone + let (entry, _) ← MetaM.run (mkCbvTheoremFromConst lemmaName (inv := inv)) {} cbvEvalExt.add entry kind } diff --git a/tests/lean/run/cbv_eval_inv.lean b/tests/lean/run/cbv_eval_inv.lean new file mode 100644 index 0000000000..c869a1cafe --- /dev/null +++ b/tests/lean/run/cbv_eval_inv.lean @@ -0,0 +1,59 @@ +import Std +set_option cbv.warning false + +-- Basic test: inverted cbv_eval attribute +-- The theorem `42 = myConst` with ← becomes `myConst = 42` +-- so cbv can rewrite `myConst` to `42` +@[cbv_opaque] def myConst : Nat := 42 + +@[cbv_eval ←] theorem myConst_eq : 42 = myConst := by rfl + +example : myConst = 42 := by + conv => + lhs + cbv + +-- Test with a function application on the RHS +def myAdd (a b : Nat) : Nat := a + b + +@[cbv_opaque] def myAddAlias (a b : Nat) : Nat := myAdd a b + +-- The theorem `myAdd a b = myAddAlias a b` with ← becomes `myAddAlias a b = myAdd a b` +-- so cbv can rewrite `myAddAlias a b` to `myAdd a b`, which it can then evaluate +@[cbv_eval ←] theorem myAddAlias_eq (a b : Nat) : myAdd a b = myAddAlias a b := by + unfold myAddAlias; rfl + +example : myAddAlias 2 3 = 5 := by + conv => + lhs + cbv + +-- Test with <- syntax (alternative arrow) +@[cbv_opaque] def myConst2 : Nat := 100 + +@[cbv_eval <-] theorem myConst2_eq : 100 = myConst2 := by rfl + +example : myConst2 = 100 := by + conv => + lhs + cbv + +-- Test that non-inverted cbv_eval still works +@[cbv_opaque] def myConst3 : Nat := 7 + +@[cbv_eval] theorem myConst3_eq : myConst3 = 7 := by rfl + +example : 7 = 7 := by + conv => + lhs + cbv + +-- Test with the optional ident argument (backward compatibility) +@[cbv_opaque] def myFn (n : Nat) : Nat := n + 1 + +@[cbv_eval myFn] theorem myFn_zero : myFn 0 = 1 := by rfl + +example : 1 = 1 := by + conv => + lhs + cbv diff --git a/tests/pkg/cbv_attr/CbvAttr/InvertedLocalTheorem.lean b/tests/pkg/cbv_attr/CbvAttr/InvertedLocalTheorem.lean new file mode 100644 index 0000000000..c823538231 --- /dev/null +++ b/tests/pkg/cbv_attr/CbvAttr/InvertedLocalTheorem.lean @@ -0,0 +1,12 @@ +module + +set_option cbv.warning false + +@[cbv_opaque] public def f7 (x : Nat) := + x + 1 + +private axiom myAx : x + 1 = f7 x + +@[local cbv_eval ←] public theorem f7_spec : x + 1 = f7 x := myAx + +example : f7 1 = 2 := by conv => lhs; cbv diff --git a/tests/pkg/cbv_attr/CbvAttr/InvertedTheorem.lean b/tests/pkg/cbv_attr/CbvAttr/InvertedTheorem.lean new file mode 100644 index 0000000000..af852aa4f2 --- /dev/null +++ b/tests/pkg/cbv_attr/CbvAttr/InvertedTheorem.lean @@ -0,0 +1,12 @@ +module + +set_option cbv.warning false + +@[cbv_opaque] public def f6 (x : Nat) := + x + 1 + +private axiom myAx : x + 1 = f6 x + +@[cbv_eval ←] public theorem f6_spec : x + 1 = f6 x := myAx + +example : f6 1 = 2 := by conv => lhs; cbv diff --git a/tests/pkg/cbv_attr/CbvAttr/Tst.lean b/tests/pkg/cbv_attr/CbvAttr/Tst.lean index 1ac989bcc1..2cd31d48f7 100644 --- a/tests/pkg/cbv_attr/CbvAttr/Tst.lean +++ b/tests/pkg/cbv_attr/CbvAttr/Tst.lean @@ -4,6 +4,8 @@ import CbvAttr.PubliclyVisibleTheorem import CbvAttr.PublicFunctionLocalTheorem import CbvAttr.PublicFunction import CbvAttr.PublicFunctionPrivateTheorem +import CbvAttr.InvertedTheorem +import CbvAttr.InvertedLocalTheorem set_option cbv.warning false @@ -39,3 +41,15 @@ error: unsolved goals -/ #guard_msgs in example : f5 1 = 2 := by conv => lhs; cbv + +/- Inverted public theorem: `x + 1 = f6 x` with ← becomes `f6 x = x + 1` -/ +example : f6 1 = 2 := by conv => lhs; cbv + +/- Inverted local theorem should not be visible across modules -/ + +/-- +error: unsolved goals +⊢ f7 1 = 2 +-/ +#guard_msgs in +example : f7 1 = 2 := by conv => lhs; cbv