feat: add ← support to cbv_eval attribute (#12506)
This PR adds the ability to register theorems with the `cbv_eval` attribute in the reverse direction using the `←` modifier, mirroring the existing `simp` attribute behavior. When `@[cbv_eval ←]` is used, the equation `lhs = rhs` is inverted to `rhs = lhs`, allowing `cbv` to rewrite occurrences of `rhs` to `lhs`. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
200f65649a
commit
424fbbdf26
6 changed files with 125 additions and 9 deletions
|
|
@ -2423,6 +2423,16 @@ defining the thing you are rewriting.
|
|||
-/
|
||||
syntax (name := method_specs_simp) "method_specs_simp" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")? (ppSpace prio)? : attr
|
||||
|
||||
/--
|
||||
Register a theorem as a rewrite rule for `cbv` evaluation of a given definition.
|
||||
|
||||
You can instruct `cbv` to rewrite the lemma from right-to-left:
|
||||
```lean
|
||||
@[cbv_eval ←] theorem my_thm : rhs = lhs := ...
|
||||
```
|
||||
-/
|
||||
syntax (name := cbv_eval) "cbv_eval" patternIgnore("← " <|> "<- ")? (ppSpace ident)? : attr
|
||||
|
||||
/-- The possible `norm_cast` kinds: `elim`, `move`, or `squash`. -/
|
||||
syntax normCastLabel := &"elim" <|> &"move" <|> &"squash"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ public import Lean.Data.NameMap
|
|||
public import Lean.ScopedEnvExtension
|
||||
public import Lean.Elab.InfoTree
|
||||
public import Lean.Meta.Sym.Simp.Theorems
|
||||
import Lean.Meta.Tactic.AuxLemma
|
||||
import Lean.Meta.AppBuilder
|
||||
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
|
|
@ -29,18 +31,24 @@ structure CbvEvalEntry where
|
|||
thm : Theorem
|
||||
deriving BEq, Inhabited
|
||||
|
||||
def mkCbvTheoremFromConst (declName : Name) : MetaM CbvEvalEntry := do
|
||||
def mkCbvTheoremFromConst (declName : Name) (inv : Bool := false) : MetaM CbvEvalEntry := do
|
||||
let cinfo ← getConstVal declName
|
||||
let us := cinfo.levelParams.map mkLevelParam
|
||||
let val := mkConst declName us
|
||||
let type ← inferType val
|
||||
unless (← isProp type) do throwError "{val} is not a theorem and thus cannot be marked with `cbv_eval` attribute"
|
||||
let fnName ← forallTelescope type fun _ body => do
|
||||
let some (_, lhs, _) := body.eq? | throwError "The conclusion {type} of theorem {val} is not an equality"
|
||||
let appFn := lhs.getAppFn
|
||||
let some constName := appFn.constName? | throwError "The left-hand side of a theorem {val} is not an application of a constant"
|
||||
return constName
|
||||
let thm ← mkTheoremFromDecl declName
|
||||
let (fnName, thmDeclName) ← forallTelescope type fun xs body => do
|
||||
let some (_, lhs, rhs) := body.eq? | throwError "The conclusion {type} of theorem {val} is not an equality"
|
||||
let matchSide := if inv then rhs else lhs
|
||||
let some constName := matchSide.getAppFn.constName?
|
||||
| throwError "The rewrite side of theorem {val} is not an application of a constant"
|
||||
let mut thmDeclName := declName
|
||||
if inv then
|
||||
let invType ← mkForallFVars xs (← mkEq rhs lhs)
|
||||
let invVal ← mkLambdaFVars xs (← mkEqSymm (mkAppN val xs))
|
||||
thmDeclName ← mkAuxLemma (kind? := `_cbv_eval) cinfo.levelParams invType invVal
|
||||
return (constName, thmDeclName)
|
||||
let thm ← mkTheoremFromDecl thmDeclName
|
||||
return ⟨fnName, thm⟩
|
||||
|
||||
structure CbvEvalState where
|
||||
|
|
@ -74,8 +82,9 @@ builtin_initialize
|
|||
name := `cbv_eval
|
||||
descr := "Register a theorem as a rewrite rule for `cbv` evaluation of a given definition."
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
add := fun lemmaName _ kind => do
|
||||
let (entry, _) ← MetaM.run (mkCbvTheoremFromConst lemmaName) {}
|
||||
add := fun lemmaName stx kind => do
|
||||
let inv := !stx[1].isNone
|
||||
let (entry, _) ← MetaM.run (mkCbvTheoremFromConst lemmaName (inv := inv)) {}
|
||||
cbvEvalExt.add entry kind
|
||||
}
|
||||
|
||||
|
|
|
|||
59
tests/lean/run/cbv_eval_inv.lean
Normal file
59
tests/lean/run/cbv_eval_inv.lean
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import Std
|
||||
set_option cbv.warning false
|
||||
|
||||
-- Basic test: inverted cbv_eval attribute
|
||||
-- The theorem `42 = myConst` with ← becomes `myConst = 42`
|
||||
-- so cbv can rewrite `myConst` to `42`
|
||||
@[cbv_opaque] def myConst : Nat := 42
|
||||
|
||||
@[cbv_eval ←] theorem myConst_eq : 42 = myConst := by rfl
|
||||
|
||||
example : myConst = 42 := by
|
||||
conv =>
|
||||
lhs
|
||||
cbv
|
||||
|
||||
-- Test with a function application on the RHS
|
||||
def myAdd (a b : Nat) : Nat := a + b
|
||||
|
||||
@[cbv_opaque] def myAddAlias (a b : Nat) : Nat := myAdd a b
|
||||
|
||||
-- The theorem `myAdd a b = myAddAlias a b` with ← becomes `myAddAlias a b = myAdd a b`
|
||||
-- so cbv can rewrite `myAddAlias a b` to `myAdd a b`, which it can then evaluate
|
||||
@[cbv_eval ←] theorem myAddAlias_eq (a b : Nat) : myAdd a b = myAddAlias a b := by
|
||||
unfold myAddAlias; rfl
|
||||
|
||||
example : myAddAlias 2 3 = 5 := by
|
||||
conv =>
|
||||
lhs
|
||||
cbv
|
||||
|
||||
-- Test with <- syntax (alternative arrow)
|
||||
@[cbv_opaque] def myConst2 : Nat := 100
|
||||
|
||||
@[cbv_eval <-] theorem myConst2_eq : 100 = myConst2 := by rfl
|
||||
|
||||
example : myConst2 = 100 := by
|
||||
conv =>
|
||||
lhs
|
||||
cbv
|
||||
|
||||
-- Test that non-inverted cbv_eval still works
|
||||
@[cbv_opaque] def myConst3 : Nat := 7
|
||||
|
||||
@[cbv_eval] theorem myConst3_eq : myConst3 = 7 := by rfl
|
||||
|
||||
example : 7 = 7 := by
|
||||
conv =>
|
||||
lhs
|
||||
cbv
|
||||
|
||||
-- Test with the optional ident argument (backward compatibility)
|
||||
@[cbv_opaque] def myFn (n : Nat) : Nat := n + 1
|
||||
|
||||
@[cbv_eval myFn] theorem myFn_zero : myFn 0 = 1 := by rfl
|
||||
|
||||
example : 1 = 1 := by
|
||||
conv =>
|
||||
lhs
|
||||
cbv
|
||||
12
tests/pkg/cbv_attr/CbvAttr/InvertedLocalTheorem.lean
Normal file
12
tests/pkg/cbv_attr/CbvAttr/InvertedLocalTheorem.lean
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
module
|
||||
|
||||
set_option cbv.warning false
|
||||
|
||||
@[cbv_opaque] public def f7 (x : Nat) :=
|
||||
x + 1
|
||||
|
||||
private axiom myAx : x + 1 = f7 x
|
||||
|
||||
@[local cbv_eval ←] public theorem f7_spec : x + 1 = f7 x := myAx
|
||||
|
||||
example : f7 1 = 2 := by conv => lhs; cbv
|
||||
12
tests/pkg/cbv_attr/CbvAttr/InvertedTheorem.lean
Normal file
12
tests/pkg/cbv_attr/CbvAttr/InvertedTheorem.lean
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
module
|
||||
|
||||
set_option cbv.warning false
|
||||
|
||||
@[cbv_opaque] public def f6 (x : Nat) :=
|
||||
x + 1
|
||||
|
||||
private axiom myAx : x + 1 = f6 x
|
||||
|
||||
@[cbv_eval ←] public theorem f6_spec : x + 1 = f6 x := myAx
|
||||
|
||||
example : f6 1 = 2 := by conv => lhs; cbv
|
||||
|
|
@ -4,6 +4,8 @@ import CbvAttr.PubliclyVisibleTheorem
|
|||
import CbvAttr.PublicFunctionLocalTheorem
|
||||
import CbvAttr.PublicFunction
|
||||
import CbvAttr.PublicFunctionPrivateTheorem
|
||||
import CbvAttr.InvertedTheorem
|
||||
import CbvAttr.InvertedLocalTheorem
|
||||
|
||||
set_option cbv.warning false
|
||||
|
||||
|
|
@ -39,3 +41,15 @@ error: unsolved goals
|
|||
-/
|
||||
#guard_msgs in
|
||||
example : f5 1 = 2 := by conv => lhs; cbv
|
||||
|
||||
/- Inverted public theorem: `x + 1 = f6 x` with ← becomes `f6 x = x + 1` -/
|
||||
example : f6 1 = 2 := by conv => lhs; cbv
|
||||
|
||||
/- Inverted local theorem should not be visible across modules -/
|
||||
|
||||
/--
|
||||
error: unsolved goals
|
||||
⊢ f7 1 = 2
|
||||
-/
|
||||
#guard_msgs in
|
||||
example : f7 1 = 2 := by conv => lhs; cbv
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue