From a471f005d61c05485a800a283f2f408890ac2f81 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 22 Dec 2025 19:54:35 -0800 Subject: [PATCH] feat: add `[grind norm]` and `[grind unfold]` attributes (#11776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds the attributes `[grind norm]` and `[grind unfold]` for controlling the `grind` normalizer/preprocessor. The `norm` modifier instructs `grind` to use a theorem as a normalization rule. That is, the theorem is applied during the preprocessing step. This feature is meant for advanced users who understand how the preprocessor and `grind`'s search procedure interact with each other. New users can still benefit from this feature by restricting its use to theorems that completely eliminate a symbol from the goal. Example: ```lean theorem max_def : max n m = if n ≤ m then m else n ``` For a negative example, consider: ```lean opaque f : Int → Int → Int → Int theorem fax1 : f x 0 1 = 1 := sorry theorem fax2 : f 1 x 1 = 1 := sorry attribute [grind norm] fax1 attribute [grind =] fax2 example (h : c = 1) : f c 0 c = 1 := by grind -- fails ``` In this example, `fax1` is a normalization rule, but it is not applicable to the input goal since `f c 0 c` is not an instance of `f x 0 1`. However, `f c 0 c` matches the pattern `f 1 x 1` modulo the equality `c = 1`. Thus, `grind` instantiates `fax2` with `x := 0`, producing the equality `f 1 0 1 = 1`, which the normalizer simplifies to `True`. As a result, nothing useful is learned. In the future, we plan to include linters to automatically detect issues like these. Example: ```lean opaque f : Nat → Nat opaque g : Nat → Nat @[grind norm] axiom fax : f x = x + 2 @[grind norm ←] axiom fg : f x = g x example : f x ≥ 2 := by grind example : f x ≥ g x := by grind example : f x + g x ≥ 4 := by grind ``` The `unfold` modifier instructs `grind` to unfold the given definition during the preprocessing step. Example: ```lean @[grind unfold] def h (x : Nat) := 2 * x example : 6 ∣ 3*h x := by grind ``` --- src/Init/Grind/Attr.lean | 51 ++++++++++++++++++- src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean | 2 +- src/Lean/Elab/Tactic/Grind/Param.lean | 4 +- src/Lean/Meta/Tactic/Grind/Attr.lean | 21 ++++++++ src/Lean/Meta/Tactic/Grind/SimpUtil.lean | 5 -- src/Lean/Meta/Tactic/Simp/Attr.lean | 43 +++++++++------- tests/lean/run/grind_norm.lean | 13 +++++ 7 files changed, 112 insertions(+), 27 deletions(-) create mode 100644 tests/lean/run/grind_norm.lean diff --git a/src/Init/Grind/Attr.lean b/src/Init/Grind/Attr.lean index 621785e629..e494c914d2 100644 --- a/src/Init/Grind/Attr.lean +++ b/src/Init/Grind/Attr.lean @@ -198,6 +198,55 @@ Given an application `f a₁ a₂ … aₙ`, when `funCC := true`, -/ syntax grindFunCC := &"funCC" /-- +The `norm` modifier instructs `grind` to use a theorem as a normalization rule. That is, +the theorem is applied during the preprocessing step. +This feature is meant for advanced users who understand how the preprocessor and `grind`'s search +procedure interact with each other. +New users can still benefit from this feature by restricting its use to theorems that completely +eliminate a symbol from the goal. Example: +``` +theorem max_def : max n m = if n ≤ m then m else n +``` +For a negative example, consider: +``` +opaque f : Int → Int → Int → Int +theorem fax1 : f x 0 1 = 1 := sorry +theorem fax2 : f 1 x 1 = 1 := sorry +attribute [grind norm] fax1 +attribute [grind =] fax2 + +example (h : c = 1) : f c 0 c = 1 := by + grind -- fails +``` +In this example, `fax1` is a normalization rule, but it is not applicable to the input goal since +`f c 0 c` is not an instance of `f x 0 1`. However, `f c 0 c` matches the pattern `f 1 x 1` modulo +the equality `c = 1`. Thus, `grind` instantiates `fax2` with `x := 0`, producing the equality +`f 1 0 1 = 1`, which the normalizer simplifies to `True`. As a result, nothing useful is learned. +In the future, we plan to include linters to automatically detect issues like these. +Example: +``` +opaque f : Nat → Nat +opaque g : Nat → Nat + +@[grind norm] axiom fax : f x = x + 2 +@[grind norm ←] axiom fg : f x = g x + +example : f x ≥ 2 := by grind +example : f x ≥ g x := by grind +example : f x + g x ≥ 4 := by grind +``` +-/ +syntax grindNorm := &"norm" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")? +/-- +The `unfold` modifier instructs `grind` to unfold the given definition during the preprocessing step. +Example: +``` +@[grind unfold] def h (x : Nat) := 2 * x +example : 6 ∣ 3*h x := by grind +``` +-/ +syntax grindUnfold := &"unfold" +/-- `symbol ` sets the priority of a constant for `grind`’s pattern-selection procedure. `grind` prefers patterns that contain higher-priority symbols. Example: @@ -224,7 +273,7 @@ syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt <|> grindGen <|> grindSym <|> grindInj - <|> grindFunCC <|> grindDef + <|> grindFunCC <|> grindNorm <|> grindUnfold <|> grindDef /-- Marks a theorem or definition for use by the `grind` tactic. diff --git a/src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean b/src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean index 23ae443b9f..cd0e58a6d4 100644 --- a/src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean +++ b/src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean @@ -245,7 +245,7 @@ where elabEMatchTheorem declName (.default false) minIndexable else return thms.toArray - | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC => + | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC | .norm .. | .unfold => throwError "invalid modifier" def logAnchor (c : SplitInfo) : TermElabM Unit := do diff --git a/src/Lean/Elab/Tactic/Grind/Param.lean b/src/Lean/Elab/Tactic/Grind/Param.lean index 61c564e68d..dc4bdebdc5 100644 --- a/src/Lean/Elab/Tactic/Grind/Param.lean +++ b/src/Lean/Elab/Tactic/Grind/Param.lean @@ -151,7 +151,7 @@ def processTermParam (params : Grind.Params) checkNoRevert params let kind ← if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer let kind ← match kind with - | .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC => + | .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC | .norm .. | .unfold => throwError "invalid `grind` parameter, only global declarations are allowed with this kind of modifier" | .ematch kind => pure kind | .infer => pure <| .default false @@ -266,6 +266,8 @@ def processParam (params : Grind.Params) params := { params with symPrios := params.symPrios.insert declName prio } | .funCC => params := params.insertFunCC declName + | .norm .. => throwError "normalization theorems should be registered using the `@[grind norm]` attribute" + | .unfold => throwError "declarations to be unfolded during normalization should be registered using the `@[grind unfold]` attribute" return params /-- diff --git a/src/Lean/Meta/Tactic/Grind/Attr.lean b/src/Lean/Meta/Tactic/Grind/Attr.lean index 53e0ea39ff..f434df16c4 100644 --- a/src/Lean/Meta/Tactic/Grind/Attr.lean +++ b/src/Lean/Meta/Tactic/Grind/Attr.lean @@ -8,10 +8,13 @@ prelude public import Lean.Meta.Tactic.Grind.Injective public import Lean.Meta.Tactic.Grind.Cases public import Lean.Meta.Tactic.Grind.ExtAttr +public import Lean.Meta.Tactic.Simp.Attr import Lean.ExtraModUses public section namespace Lean.Meta.Grind +builtin_initialize normExt : SimpExtension ← mkSimpExt + inductive AttrKind where | ematch (k : EMatchTheoremKind) | cases (eager : Bool) @@ -21,6 +24,8 @@ inductive AttrKind where | symbol (prio : Nat) | inj | funCC + | norm (post : Bool) (inv : Bool) + | unfold /-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do @@ -47,6 +52,13 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do | `(Parser.Attr.grindMod|ext) => return .ext | `(Parser.Attr.grindMod|inj) => return .inj | `(Parser.Attr.grindMod|funCC) => return .funCC + | `(Parser.Attr.grindMod|norm) => return .norm true false + | `(Parser.Attr.grindMod|norm ↑) => return .norm true false + | `(Parser.Attr.grindMod|norm ↓) => return .norm (post := false) false + | `(Parser.Attr.grindMod|norm ←) => return .norm true true + | `(Parser.Attr.grindMod|norm ↑ ←) => return .norm true true + | `(Parser.Attr.grindMod|norm ↓ ←) => return .norm (post := false) true + | `(Parser.Attr.grindMod|unfold) => return .unfold | `(Parser.Attr.grindMod|symbol $prio:prio) => let some prio := prio.raw.isNatLit? | throwErrorAt prio "priority expected" return .symbol prio @@ -158,6 +170,15 @@ private def mkGrindAttr (attrName : Name) (minIndexable : Bool) (showInfo : Bool unless attrName == `grind do throwError "symbol priorities must be set using the default `[grind]` attribute" addSymbolPriorityAttr declName attrKind prio + | .norm post inv => + unless attrName == `grind do + throwError "normalizer must be set using the default `[grind]` attribute" + addSimpTheorem normExt declName (post := post) (inv := inv) attrKind (eval_prio default) + | .unfold => + unless attrName == `grind do + throwError "declaration to unfold must be set using the default `[grind]` attribute" + unless (← addDeclToUnfold normExt declName (post := false) (inv := false) (prio := eval_prio default) (attrKind := attrKind)) do + throwError "cannot mark declaration to be unfolded by `grind`" | .cases eager => ext.addCasesAttr declName eager attrKind | .funCC => ext.addFunCCAttr declName attrKind | .ext => ext.addExtAttr declName attrKind diff --git a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean index 43c89e3812..d6cb676d83 100644 --- a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean @@ -18,11 +18,6 @@ import Init.Grind.Norm public section namespace Lean.Meta.Grind -/- -TODO: group into a `grind` extension object --/ -builtin_initialize normExt : SimpExtension ← mkSimpExt - def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name) : MetaM Unit := do let thms ← normExt.getTheorems unless thms.lemmaNames.isEmpty do diff --git a/src/Lean/Meta/Tactic/Simp/Attr.lean b/src/Lean/Meta/Tactic/Simp/Attr.lean index ae08722188..6f32b30f1a 100644 --- a/src/Lean/Meta/Tactic/Simp/Attr.lean +++ b/src/Lean/Meta/Tactic/Simp/Attr.lean @@ -4,15 +4,35 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Lean.Meta.Tactic.Simp.Simproc - public section - namespace Lean.Meta open Simp +/-- +Marks `declName` to be unfolded in the given `SimpExtension`. +-/ +def addDeclToUnfold (ext : SimpExtension) (declName : Name) (post inv : Bool) (prio : Nat) (attrKind : AttributeKind) : MetaM Bool := do + if getOriginalConstKind? (← getEnv) declName == some .defn then + if inv then + throwError m!"Invalid `←` modifier: `{.ofConstName declName}` is a declaration name to be unfolded" + ++ .note m!"The simplifier will automatically unfold definitions marked with the `[simp]` \ + attribute, but it will not \"refold\" them" + if (← Simp.ignoreEquations declName) then + ext.add (SimpEntry.toUnfold declName) attrKind + else if let some eqns ← getEqnsFor? declName then + for eqn in eqns do + addSimpTheorem ext eqn post (inv := false) attrKind prio + ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind + if (← Simp.unfoldEvenWithEqns declName) then + ext.add (SimpEntry.toUnfold declName) attrKind + else + ext.add (SimpEntry.toUnfold declName) attrKind + return true + else + return false + def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension) (ref : Name := by exact decl_name%) : IO Unit := registerBuiltinAttribute { @@ -32,22 +52,7 @@ def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension) let prio ← getAttrParamOptPrio stx[3] if (← isProp info.sig.get.type) then addSimpTheorem ext declName post (inv := inv) attrKind prio - else if getOriginalConstKind? (← getEnv) declName == some .defn then - if inv then - throwError m!"Invalid `←` modifier: `{.ofConstName declName}` is a declaration name to be unfolded" - ++ .note m!"The simplifier will automatically unfold definitions marked with the `[simp]` \ - attribute, but it will not \"refold\" them" - if (← Simp.ignoreEquations declName) then - ext.add (SimpEntry.toUnfold declName) attrKind - else if let some eqns ← getEqnsFor? declName then - for eqn in eqns do - addSimpTheorem ext eqn post (inv := false) attrKind prio - ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind - if (← Simp.unfoldEvenWithEqns declName) then - ext.add (SimpEntry.toUnfold declName) attrKind - else - ext.add (SimpEntry.toUnfold declName) attrKind - else + else unless (← addDeclToUnfold ext declName post inv prio attrKind) do throwError m!"Cannot add `simp` attribute to `{.ofConstName declName}`: It is not a proposition nor a definition (to unfold)" ++ .note m!"The `[simp]` attribute can be added to lemmas that should be automatically used by the simplifier \ and to definitions that the simplifier should automatically unfold" diff --git a/tests/lean/run/grind_norm.lean b/tests/lean/run/grind_norm.lean new file mode 100644 index 0000000000..114af25a1b --- /dev/null +++ b/tests/lean/run/grind_norm.lean @@ -0,0 +1,13 @@ +opaque f : Nat → Nat +opaque g : Nat → Nat + +@[grind norm] axiom fax : f x = x + 2 +@[grind norm ←] axiom fg : f x = g x + +example : f x ≥ 2 := by grind +example : f x ≥ g x := by grind +example : f x + g x ≥ 4 := by grind + +@[grind unfold] def h (x : Nat) := 2 * x + +example : 2 ∣ h x := by grind