feat: add [grind norm] and [grind unfold] attributes (#11776)
This PR adds the attributes `[grind norm]` and `[grind unfold]` for controlling the `grind` normalizer/preprocessor. The `norm` modifier instructs `grind` to use a theorem as a normalization rule. That is, the theorem is applied during the preprocessing step. This feature is meant for advanced users who understand how the preprocessor and `grind`'s search procedure interact with each other. New users can still benefit from this feature by restricting its use to theorems that completely eliminate a symbol from the goal. Example: ```lean theorem max_def : max n m = if n ≤ m then m else n ``` For a negative example, consider: ```lean opaque f : Int → Int → Int → Int theorem fax1 : f x 0 1 = 1 := sorry theorem fax2 : f 1 x 1 = 1 := sorry attribute [grind norm] fax1 attribute [grind =] fax2 example (h : c = 1) : f c 0 c = 1 := by grind -- fails ``` In this example, `fax1` is a normalization rule, but it is not applicable to the input goal since `f c 0 c` is not an instance of `f x 0 1`. However, `f c 0 c` matches the pattern `f 1 x 1` modulo the equality `c = 1`. Thus, `grind` instantiates `fax2` with `x := 0`, producing the equality `f 1 0 1 = 1`, which the normalizer simplifies to `True`. As a result, nothing useful is learned. In the future, we plan to include linters to automatically detect issues like these. Example: ```lean opaque f : Nat → Nat opaque g : Nat → Nat @[grind norm] axiom fax : f x = x + 2 @[grind norm ←] axiom fg : f x = g x example : f x ≥ 2 := by grind example : f x ≥ g x := by grind example : f x + g x ≥ 4 := by grind ``` The `unfold` modifier instructs `grind` to unfold the given definition during the preprocessing step. Example: ```lean @[grind unfold] def h (x : Nat) := 2 * x example : 6 ∣ 3*h x := by grind ```
This commit is contained in:
parent
f6a25b13b9
commit
a471f005d6
7 changed files with 112 additions and 27 deletions
|
|
@ -198,6 +198,55 @@ Given an application `f a₁ a₂ … aₙ`, when `funCC := true`,
|
|||
-/
|
||||
syntax grindFunCC := &"funCC"
|
||||
/--
|
||||
The `norm` modifier instructs `grind` to use a theorem as a normalization rule. That is,
|
||||
the theorem is applied during the preprocessing step.
|
||||
This feature is meant for advanced users who understand how the preprocessor and `grind`'s search
|
||||
procedure interact with each other.
|
||||
New users can still benefit from this feature by restricting its use to theorems that completely
|
||||
eliminate a symbol from the goal. Example:
|
||||
```
|
||||
theorem max_def : max n m = if n ≤ m then m else n
|
||||
```
|
||||
For a negative example, consider:
|
||||
```
|
||||
opaque f : Int → Int → Int → Int
|
||||
theorem fax1 : f x 0 1 = 1 := sorry
|
||||
theorem fax2 : f 1 x 1 = 1 := sorry
|
||||
attribute [grind norm] fax1
|
||||
attribute [grind =] fax2
|
||||
|
||||
example (h : c = 1) : f c 0 c = 1 := by
|
||||
grind -- fails
|
||||
```
|
||||
In this example, `fax1` is a normalization rule, but it is not applicable to the input goal since
|
||||
`f c 0 c` is not an instance of `f x 0 1`. However, `f c 0 c` matches the pattern `f 1 x 1` modulo
|
||||
the equality `c = 1`. Thus, `grind` instantiates `fax2` with `x := 0`, producing the equality
|
||||
`f 1 0 1 = 1`, which the normalizer simplifies to `True`. As a result, nothing useful is learned.
|
||||
In the future, we plan to include linters to automatically detect issues like these.
|
||||
Example:
|
||||
```
|
||||
opaque f : Nat → Nat
|
||||
opaque g : Nat → Nat
|
||||
|
||||
@[grind norm] axiom fax : f x = x + 2
|
||||
@[grind norm ←] axiom fg : f x = g x
|
||||
|
||||
example : f x ≥ 2 := by grind
|
||||
example : f x ≥ g x := by grind
|
||||
example : f x + g x ≥ 4 := by grind
|
||||
```
|
||||
-/
|
||||
syntax grindNorm := &"norm" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")?
|
||||
/--
|
||||
The `unfold` modifier instructs `grind` to unfold the given definition during the preprocessing step.
|
||||
Example:
|
||||
```
|
||||
@[grind unfold] def h (x : Nat) := 2 * x
|
||||
example : 6 ∣ 3*h x := by grind
|
||||
```
|
||||
-/
|
||||
syntax grindUnfold := &"unfold"
|
||||
/--
|
||||
`symbol <prio>` sets the priority of a constant for `grind`’s pattern-selection
|
||||
procedure. `grind` prefers patterns that contain higher-priority symbols.
|
||||
Example:
|
||||
|
|
@ -224,7 +273,7 @@ syntax grindMod :=
|
|||
grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd
|
||||
<|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager
|
||||
<|> grindCases <|> grindIntro <|> grindExt <|> grindGen <|> grindSym <|> grindInj
|
||||
<|> grindFunCC <|> grindDef
|
||||
<|> grindFunCC <|> grindNorm <|> grindUnfold <|> grindDef
|
||||
|
||||
/--
|
||||
Marks a theorem or definition for use by the `grind` tactic.
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ where
|
|||
elabEMatchTheorem declName (.default false) minIndexable
|
||||
else
|
||||
return thms.toArray
|
||||
| .cases _ | .intro | .inj | .ext | .symbol _ | .funCC =>
|
||||
| .cases _ | .intro | .inj | .ext | .symbol _ | .funCC | .norm .. | .unfold =>
|
||||
throwError "invalid modifier"
|
||||
|
||||
def logAnchor (c : SplitInfo) : TermElabM Unit := do
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ def processTermParam (params : Grind.Params)
|
|||
checkNoRevert params
|
||||
let kind ← if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer
|
||||
let kind ← match kind with
|
||||
| .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC =>
|
||||
| .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC | .norm .. | .unfold =>
|
||||
throwError "invalid `grind` parameter, only global declarations are allowed with this kind of modifier"
|
||||
| .ematch kind => pure kind
|
||||
| .infer => pure <| .default false
|
||||
|
|
@ -266,6 +266,8 @@ def processParam (params : Grind.Params)
|
|||
params := { params with symPrios := params.symPrios.insert declName prio }
|
||||
| .funCC =>
|
||||
params := params.insertFunCC declName
|
||||
| .norm .. => throwError "normalization theorems should be registered using the `@[grind norm]` attribute"
|
||||
| .unfold => throwError "declarations to be unfolded during normalization should be registered using the `@[grind unfold]` attribute"
|
||||
return params
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@ prelude
|
|||
public import Lean.Meta.Tactic.Grind.Injective
|
||||
public import Lean.Meta.Tactic.Grind.Cases
|
||||
public import Lean.Meta.Tactic.Grind.ExtAttr
|
||||
public import Lean.Meta.Tactic.Simp.Attr
|
||||
import Lean.ExtraModUses
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
builtin_initialize normExt : SimpExtension ← mkSimpExt
|
||||
|
||||
inductive AttrKind where
|
||||
| ematch (k : EMatchTheoremKind)
|
||||
| cases (eager : Bool)
|
||||
|
|
@ -21,6 +24,8 @@ inductive AttrKind where
|
|||
| symbol (prio : Nat)
|
||||
| inj
|
||||
| funCC
|
||||
| norm (post : Bool) (inv : Bool)
|
||||
| unfold
|
||||
|
||||
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
|
||||
def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
|
||||
|
|
@ -47,6 +52,13 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
|
|||
| `(Parser.Attr.grindMod|ext) => return .ext
|
||||
| `(Parser.Attr.grindMod|inj) => return .inj
|
||||
| `(Parser.Attr.grindMod|funCC) => return .funCC
|
||||
| `(Parser.Attr.grindMod|norm) => return .norm true false
|
||||
| `(Parser.Attr.grindMod|norm ↑) => return .norm true false
|
||||
| `(Parser.Attr.grindMod|norm ↓) => return .norm (post := false) false
|
||||
| `(Parser.Attr.grindMod|norm ←) => return .norm true true
|
||||
| `(Parser.Attr.grindMod|norm ↑ ←) => return .norm true true
|
||||
| `(Parser.Attr.grindMod|norm ↓ ←) => return .norm (post := false) true
|
||||
| `(Parser.Attr.grindMod|unfold) => return .unfold
|
||||
| `(Parser.Attr.grindMod|symbol $prio:prio) =>
|
||||
let some prio := prio.raw.isNatLit? | throwErrorAt prio "priority expected"
|
||||
return .symbol prio
|
||||
|
|
@ -158,6 +170,15 @@ private def mkGrindAttr (attrName : Name) (minIndexable : Bool) (showInfo : Bool
|
|||
unless attrName == `grind do
|
||||
throwError "symbol priorities must be set using the default `[grind]` attribute"
|
||||
addSymbolPriorityAttr declName attrKind prio
|
||||
| .norm post inv =>
|
||||
unless attrName == `grind do
|
||||
throwError "normalizer must be set using the default `[grind]` attribute"
|
||||
addSimpTheorem normExt declName (post := post) (inv := inv) attrKind (eval_prio default)
|
||||
| .unfold =>
|
||||
unless attrName == `grind do
|
||||
throwError "declaration to unfold must be set using the default `[grind]` attribute"
|
||||
unless (← addDeclToUnfold normExt declName (post := false) (inv := false) (prio := eval_prio default) (attrKind := attrKind)) do
|
||||
throwError "cannot mark declaration to be unfolded by `grind`"
|
||||
| .cases eager => ext.addCasesAttr declName eager attrKind
|
||||
| .funCC => ext.addFunCCAttr declName attrKind
|
||||
| .ext => ext.addExtAttr declName attrKind
|
||||
|
|
|
|||
|
|
@ -18,11 +18,6 @@ import Init.Grind.Norm
|
|||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
builtin_initialize normExt : SimpExtension ← mkSimpExt
|
||||
|
||||
def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name) : MetaM Unit := do
|
||||
let thms ← normExt.getTheorems
|
||||
unless thms.lemmaNames.isEmpty do
|
||||
|
|
|
|||
|
|
@ -4,15 +4,35 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Simp.Simproc
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Meta
|
||||
open Simp
|
||||
|
||||
/--
|
||||
Marks `declName` to be unfolded in the given `SimpExtension`.
|
||||
-/
|
||||
def addDeclToUnfold (ext : SimpExtension) (declName : Name) (post inv : Bool) (prio : Nat) (attrKind : AttributeKind) : MetaM Bool := do
|
||||
if getOriginalConstKind? (← getEnv) declName == some .defn then
|
||||
if inv then
|
||||
throwError m!"Invalid `←` modifier: `{.ofConstName declName}` is a declaration name to be unfolded"
|
||||
++ .note m!"The simplifier will automatically unfold definitions marked with the `[simp]` \
|
||||
attribute, but it will not \"refold\" them"
|
||||
if (← Simp.ignoreEquations declName) then
|
||||
ext.add (SimpEntry.toUnfold declName) attrKind
|
||||
else if let some eqns ← getEqnsFor? declName then
|
||||
for eqn in eqns do
|
||||
addSimpTheorem ext eqn post (inv := false) attrKind prio
|
||||
ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind
|
||||
if (← Simp.unfoldEvenWithEqns declName) then
|
||||
ext.add (SimpEntry.toUnfold declName) attrKind
|
||||
else
|
||||
ext.add (SimpEntry.toUnfold declName) attrKind
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension)
|
||||
(ref : Name := by exact decl_name%) : IO Unit :=
|
||||
registerBuiltinAttribute {
|
||||
|
|
@ -32,22 +52,7 @@ def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension)
|
|||
let prio ← getAttrParamOptPrio stx[3]
|
||||
if (← isProp info.sig.get.type) then
|
||||
addSimpTheorem ext declName post (inv := inv) attrKind prio
|
||||
else if getOriginalConstKind? (← getEnv) declName == some .defn then
|
||||
if inv then
|
||||
throwError m!"Invalid `←` modifier: `{.ofConstName declName}` is a declaration name to be unfolded"
|
||||
++ .note m!"The simplifier will automatically unfold definitions marked with the `[simp]` \
|
||||
attribute, but it will not \"refold\" them"
|
||||
if (← Simp.ignoreEquations declName) then
|
||||
ext.add (SimpEntry.toUnfold declName) attrKind
|
||||
else if let some eqns ← getEqnsFor? declName then
|
||||
for eqn in eqns do
|
||||
addSimpTheorem ext eqn post (inv := false) attrKind prio
|
||||
ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind
|
||||
if (← Simp.unfoldEvenWithEqns declName) then
|
||||
ext.add (SimpEntry.toUnfold declName) attrKind
|
||||
else
|
||||
ext.add (SimpEntry.toUnfold declName) attrKind
|
||||
else
|
||||
else unless (← addDeclToUnfold ext declName post inv prio attrKind) do
|
||||
throwError m!"Cannot add `simp` attribute to `{.ofConstName declName}`: It is not a proposition nor a definition (to unfold)"
|
||||
++ .note m!"The `[simp]` attribute can be added to lemmas that should be automatically used by the simplifier \
|
||||
and to definitions that the simplifier should automatically unfold"
|
||||
|
|
|
|||
13
tests/lean/run/grind_norm.lean
Normal file
13
tests/lean/run/grind_norm.lean
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
opaque f : Nat → Nat
|
||||
opaque g : Nat → Nat
|
||||
|
||||
@[grind norm] axiom fax : f x = x + 2
|
||||
@[grind norm ←] axiom fg : f x = g x
|
||||
|
||||
example : f x ≥ 2 := by grind
|
||||
example : f x ≥ g x := by grind
|
||||
example : f x + g x ≥ 4 := by grind
|
||||
|
||||
@[grind unfold] def h (x : Nat) := 2 * x
|
||||
|
||||
example : 2 ∣ h x := by grind
|
||||
Loading…
Add table
Reference in a new issue