feat: attribute [grind] (#6545)

This PR introduces the parametric attribute `[grind]` for annotating
theorems and definitions. It also replaces `[grind_eq]` with `[grind
=]`. For definitions, `[grind]` is equivalent to `[grind =]`.

The new attribute supports the following variants:

- **`[grind =]`**: Uses the left-hand side of the theorem's conclusion
as the pattern for E-matching.
- **`[grind =_]`**: Uses the right-hand side of the theorem's conclusion
as the pattern for E-matching.
- **`[grind _=_]`**: Creates two patterns. One for the left-hand side
and one for the right-hand side.
- **`[grind →]`**: Searches for (multi-)patterns in the theorem's
antecedents, stopping once a usable multi-pattern is found.
- **`[grind ←]`**: Searches for (multi-)patterns in the theorem's
conclusion, stopping once a usable multi-pattern is found.
- **`[grind]`**: Searches for (multi-)patterns in both the theorem's
conclusion and antecedents. It starts with the conclusion and stops once
a usable multi-pattern is found.

The `grind_pattern` command remains available for cases where these
attributes do not yield the desired result.
This commit is contained in:
Leonardo de Moura 2025-01-05 19:05:20 -08:00 committed by GitHub
parent 76f883b999
commit 2ed77f3b26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 380 additions and 40 deletions

View file

@ -6,6 +6,18 @@ Authors: Leonardo de Moura
prelude
import Init.Tactics
namespace Lean.Parser.Attr
syntax grindEq := "="
syntax grindEqBoth := "_=_"
syntax grindEqRhs := "=_"
syntax grindBwd := "←"
syntax grindFwd := "→"
syntax (name := grind) "grind" (grindEq <|> grindBwd <|> grindFwd <|> grindEqBoth <|> grindEqRhs)? : attr
end Lean.Parser.Attr
namespace Lean.Grind
/--
The configuration for `grind`.

View file

@ -34,6 +34,7 @@ builtin_initialize registerTraceClass `grind.eqc
builtin_initialize registerTraceClass `grind.internalize
builtin_initialize registerTraceClass `grind.ematch
builtin_initialize registerTraceClass `grind.ematch.pattern
builtin_initialize registerTraceClass `grind.ematch.pattern.search
builtin_initialize registerTraceClass `grind.ematch.instance
builtin_initialize registerTraceClass `grind.ematch.instance.assignment
builtin_initialize registerTraceClass `grind.issues

View file

@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Grind.Util
import Init.Grind.Tactics
import Lean.HeadIndex
import Lean.PrettyPrinter
import Lean.Util.FoldConsts
@ -218,16 +219,18 @@ private def getPatternFn? (pattern : Expr) : Option Expr :=
/--
Returns a bit-mask `mask` s.t. `mask[i]` is true if the the corresponding argument is
- a type or type former, or
- a type (that is not a proposition) or type former, or
- a proof, or
- an instance implicit argument
When `mask[i]`, we say the corresponding argument is a "support" argument.
-/
private def getPatternFunMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do
def getPatternSupportMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do
forallBoundedTelescope (← inferType f) numArgs fun xs _ => do
xs.mapM fun x => do
if (← isTypeFormer x <||> isProof x) then
if (← isProp x) then
return false
else if (← isTypeFormer x <||> isProof x) then
return true
else
return (← x.fvarId!.getDecl).binderInfo matches .instImplicit
@ -246,7 +249,7 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
assert! f.isConst || f.isFVar
saveSymbol f.toHeadIndex
let mut args := pattern.getAppArgs
let supportMask ← getPatternFunMask f args.size
let supportMask ← getPatternSupportMask f args.size
for i in [:args.size] do
let arg := args[i]!
let isSupport := supportMask[i]?.getD false
@ -278,6 +281,9 @@ def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.Hash
let (patterns, s) ← patterns.mapM go |>.run {}
return (patterns, s.symbols.toList, s.bvarsFound)
def normalizePattern (e : Expr) : M Expr := do
go e
end NormalizePattern
/--
@ -402,26 +408,50 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
msg := msg ++ m!"{x} : {← inferType x}"
addMessageContextFull msg
/--
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
return {
proof, patterns, numParams, symbols
levelParams, origin
}
private def getProofFor (declName : Name) : CoreM Expr := do
let .thmInfo info ← getConstInfo declName
| throwError "`{declName}` is not a theorem"
let us := info.levelParams.map mkLevelParam
return mkConst declName us
/--
Creates an E-matching theorem for `declName` with `numParams` parameters, and the given set of patterns.
Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM EMatchTheorem := do
let .thmInfo info ← getConstInfo declName
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
let us := info.levelParams.map mkLevelParam
let proof := mkConst declName us
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
assert! symbols.all fun s => s matches .const _
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
return {
proof, patterns, numParams, symbols
levelParams := #[]
origin := .decl declName
}
mkEMatchTheoremCore (.decl declName) #[] numParams (← getProofFor declName) patterns
/--
Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern.
-/
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) : MetaM EMatchTheorem := do
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
let (lhs, rhs) ← match_expr type with
| Eq _ lhs rhs => pure (lhs, rhs)
| Iff lhs rhs => pure (lhs, rhs)
| HEq _ lhs _ rhs => pure (lhs, rhs)
| _ => throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
let pat := if useLhs then lhs else rhs
let pat ← preprocessPattern pat normalizePattern
return (xs.size, [pat.abstract xs])
mkEMatchTheoremCore origin levelParams numParams proof patterns
/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
@ -430,17 +460,8 @@ creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the
pattern.
-/
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) : MetaM EMatchTheorem := do
let info ← getConstInfo declName
let (numParams, patterns) ← forallTelescopeReducing info.type fun xs type => do
let lhs ← match_expr type with
| Eq _ lhs _ => pure lhs
| Iff lhs _ => pure lhs
| HEq _ lhs _ _ => pure lhs
| _ => throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
let lhs ← preprocessPattern lhs normalizePattern
return (xs.size, [lhs.abstract xs])
mkEMatchTheorem declName numParams patterns
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) : MetaM EMatchTheorem := do
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs
/--
Adds an E-matching theorem to the environment.
@ -460,18 +481,177 @@ def addEMatchEqTheorem (declName : Name) : MetaM Unit := do
def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState (← getEnv)
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
private inductive TheoremKind where
| eqLhs | eqRhs | eqBoth | fwd | bwd | default
deriving Inhabited, BEq
private def TheoremKind.toAttribute : TheoremKind → String
| .eqLhs => "[grind =]"
| .eqRhs => "[grind =_]"
| .eqBoth => "[grind _=_]"
| .fwd => "[grind →]"
| .bwd => "[grind ←]"
| .default => "[grind]"
private def TheoremKind.explainFailure : TheoremKind → String
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
| .eqBoth => unreachable! -- eqBoth is a macro
| .fwd => "failed to find patterns in the antecedents of the theorem"
| .bwd => "failed to find patterns in the theorem's conclusion"
| .default => "failed to find patterns"
/-- Returns the types of `xs` that are propositions. -/
private def getPropTypes (xs : Array Expr) : MetaM (Array Expr) :=
xs.filterMapM fun x => do
let type ← inferType x
if (← isProp type) then return some type else return none
/-- State for the (pattern) `CollectorM` monad -/
private structure Collector.State where
/-- Pattern found so far. -/
patterns : Array Expr := #[]
done : Bool := false
private structure Collector.Context where
proof : Expr
xs : Array Expr
/-- Monad for collecting patterns for a theorem. -/
private abbrev CollectorM := ReaderT Collector.Context $ StateRefT Collector.State NormalizePattern.M
/-- Similar to `getPatternFn?`, but operates on expressions that do not contain loose de Bruijn variables. -/
private def isPatternFnCandidate (f : Expr) : CollectorM Bool := do
match f with
| .const declName _ => return !isForbidden declName
| .fvar .. => return !(← read).xs.contains f
| _ => return false
private def addNewPattern (p : Expr) : CollectorM Unit := do
trace[grind.ematch.pattern.search] "found pattern: {ppPattern p}"
let bvarsFound := (← getThe NormalizePattern.State).bvarsFound
let done := (← checkCoverage (← read).proof (← read).xs.size bvarsFound) matches .ok
if done then
trace[grind.ematch.pattern.search] "found full coverage"
modify fun s => { s with patterns := s.patterns.push p, done }
private partial def collect (e : Expr) : CollectorM Unit := do
if (← get).done then return ()
match e with
| .app .. =>
let f := e.getAppFn
if (← isPatternFnCandidate f) then
let saved ← getThe NormalizePattern.State
try
trace[grind.ematch.pattern.search] "candidate: {e}"
let p := e.abstract (← read).xs
unless p.hasLooseBVars do
trace[grind.ematch.pattern.search] "skip, does not contain pattern variables"
return ()
let p ← NormalizePattern.normalizePattern p
if saved.bvarsFound.size < (← getThe NormalizePattern.State).bvarsFound.size then
addNewPattern p
return ()
trace[grind.ematch.pattern.search] "skip, no new variables covered"
-- restore state and continue search
set saved
catch _ =>
-- restore state and continue search
trace[grind.ematch.pattern.search] "skip, exception during normalization"
set saved
let args := e.getAppArgs
for arg in args, flag in (← NormalizePattern.getPatternSupportMask f args.size) do
unless flag do
collect arg
| .forallE _ d b _ =>
if (← pure e.isArrow <&&> isProp d <&&> isProp b) then
collect d
collect b
| _ => return ()
private def collectPatterns? (proof : Expr) (xs : Array Expr) (searchPlaces : Array Expr) : MetaM (Option (List Expr × List HeadIndex)) := do
let go : CollectorM (Option (List Expr)) := do
for place in searchPlaces do
let place ← preprocessPattern place
collect place
if (← get).done then
return some ((← get).patterns.toList)
return none
let (some ps, s) ← go { proof, xs } |>.run' {} |>.run {}
| return none
return some (ps, s.symbols.toList)
private def mkEMatchTheoremWithKind? (origin : Origin) (levelParams : Array Name) (proof : Expr) (kind : TheoremKind) : MetaM (Option EMatchTheorem) := do
if kind == .eqLhs then
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := false) (useLhs := true))
else if kind == .eqRhs then
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := false) (useLhs := false))
let type ← inferType proof
forallTelescopeReducing type fun xs type => do
let searchPlaces ← match kind with
| .fwd =>
let ps ← getPropTypes xs
if ps.isEmpty then
throwError "invalid `grind` forward theorem, theorem `{← origin.pp}` does not have proposional hypotheses"
pure ps
| .bwd => pure #[type]
| .default => pure <| #[type] ++ (← getPropTypes xs)
| _ => unreachable!
go xs searchPlaces
where
go (xs : Array Expr) (searchPlaces : Array Expr) : MetaM (Option EMatchTheorem) := do
let some (patterns, symbols) ← collectPatterns? proof xs searchPlaces
| return none
let numParams := xs.size
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
return some {
proof, patterns, numParams, symbols
levelParams, origin
}
private def getKind (stx : Syntax) : TheoremKind :=
if stx[1].isNone then
.default
else if stx[1][0].getKind == ``Parser.Attr.grindEq then
.eqLhs
else if stx[1][0].getKind == ``Parser.Attr.grindFwd then
.fwd
else if stx[1][0].getKind == ``Parser.Attr.grindEqRhs then
.eqRhs
else if stx[1][0].getKind == ``Parser.Attr.grindEqBoth then
.eqBoth
else
.bwd
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (useLhs := true) : MetaM Unit := do
if (← getConstInfo declName).isTheorem then
ematchTheoremsExt.add (← mkEMatchEqTheorem declName) attrKind
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs)) attrKind
else if let some eqns ← getEqnsFor? declName then
unless useLhs do
throwError "`{declName}` is a definition, you must only use the left-hand side for extracting patterns"
for eqn in eqns do
ematchTheoremsExt.add (← mkEMatchEqTheorem eqn) attrKind
else
throwError "`[grind_eq]` attribute can only be applied to equational theorems or function definitions"
private def addGrindAttr (declName : Name) (attrKind : AttributeKind) (thmKind : TheoremKind) : MetaM Unit := do
if thmKind == .eqLhs then
addGrindEqAttr declName attrKind (useLhs := true)
else if thmKind == .eqRhs then
addGrindEqAttr declName attrKind (useLhs := false)
else if thmKind == .eqBoth then
addGrindEqAttr declName attrKind (useLhs := true)
addGrindEqAttr declName attrKind (useLhs := false)
else if !(← getConstInfo declName).isTheorem then
addGrindEqAttr declName attrKind
else
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
ematchTheoremsExt.add thm attrKind
builtin_initialize
registerBuiltinAttribute {
name := `grind_eq
name := `grind
descr :=
"The `[grind_eq]` attribute is used to annotate equational theorems and functions.\
When applied to an equational theorem, it marks the theorem for use in heuristic instantiations by the `grind` tactic.\
@ -480,8 +660,8 @@ builtin_initialize
For example, if a theorem `@[grind_eq] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\
`grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`."
applicationTime := .afterCompilation
add := fun declName _ attrKind =>
addGrindEqAttr declName attrKind |>.run' {}
add := fun declName stx attrKind => do
addGrindAttr declName attrKind (getKind stx) |>.run' {}
}
end Lean.Meta.Grind

View file

@ -69,3 +69,148 @@ info: [grind.ematch.instance] Rtrans: R a d → R d e → R a e
#guard_msgs (info) in
example : R a b → R b c → R c d → R d e → R a d := by
grind
namespace using_grind_fwd
opaque S : Nat → Nat → Prop
/--
error: `@[grind →] theorem using_grind_fwd.StransBad` failed to find patterns in the antecedents of the theorem, consider using different options or the `grind_pattern` command
-/
#guard_msgs (error) in
@[grind→] theorem StransBad (a b c d : Nat) : S a b R a b → S b c → S a c ∧ S b d := sorry
set_option trace.grind.ematch.pattern.search true in
/--
info: [grind.ematch.pattern.search] candidate: S a b
[grind.ematch.pattern.search] found pattern: S #4 #3
[grind.ematch.pattern.search] candidate: R a b
[grind.ematch.pattern.search] skip, no new variables covered
[grind.ematch.pattern.search] candidate: S b c
[grind.ematch.pattern.search] found pattern: S #3 #2
[grind.ematch.pattern.search] found full coverage
[grind.ematch.pattern] Strans: [S #4 #3, S #3 #2]
-/
#guard_msgs (info) in
@[grind→] theorem Strans (a b c : Nat) : S a b R a b → S b c → S a c := sorry
/--
info: [grind.ematch.instance] Strans: S a b R a b → S b c → S a c
-/
#guard_msgs (info) in
example : S a b → S b c → S a c := by
grind
end using_grind_fwd
namespace using_grind_bwd
opaque P : Nat → Prop
opaque Q : Nat → Prop
opaque f : Nat → Nat → Nat
/--
info: [grind.ematch.pattern] pqf: [P (f #2 #1)]
-/
#guard_msgs (info) in
@[grind←] theorem pqf : Q x → P (f x y) := sorry
/--
info: [grind.ematch.instance] pqf: Q a → P (f a b)
-/
#guard_msgs (info) in
example : Q 0 → Q 1 → Q 2 → Q 3 → ¬ P (f a b) → a = 1 → False := by
grind
end using_grind_bwd
namespace using_grind_fwd2
opaque P : Nat → Prop
opaque Q : Nat → Prop
opaque f : Nat → Nat → Nat
/--
error: `@[grind →] theorem using_grind_fwd2.pqfBad` failed to find patterns in the antecedents of the theorem, consider using different options or the `grind_pattern` command
-/
#guard_msgs (error) in
@[grind→] theorem pqfBad : Q x → P (f x y) := sorry
/--
info: [grind.ematch.pattern] pqf: [Q #1]
-/
#guard_msgs (info) in
@[grind→] theorem pqf : Q x → P (f x x) := sorry
/--
info: [grind.ematch.instance] pqf: Q 3 → P (f 3 3)
[grind.ematch.instance] pqf: Q 2 → P (f 2 2)
[grind.ematch.instance] pqf: Q 1 → P (f 1 1)
[grind.ematch.instance] pqf: Q 0 → P (f 0 0)
-/
#guard_msgs (info) in
example : Q 0 → Q 1 → Q 2 → Q 3 → ¬ P (f a a) → a = 1 → False := by
grind
end using_grind_fwd2
namespace using_grind_mixed
opaque P : Nat → Nat → Prop
opaque Q : Nat → Nat → Prop
/--
error: `@[grind →] theorem using_grind_mixed.pqBad1` failed to find patterns in the antecedents of the theorem, consider using different options or the `grind_pattern` command
-/
#guard_msgs (error) in
@[grind→] theorem pqBad1 : P x y → Q x z := sorry
/--
error: `@[grind ←] theorem using_grind_mixed.pqBad2` failed to find patterns in the theorem's conclusion, consider using different options or the `grind_pattern` command
-/
#guard_msgs (error) in
@[grind←] theorem pqBad2 : P x y → Q x z := sorry
/--
info: [grind.ematch.pattern] pqBad: [Q #3 #1, P #3 #2]
-/
#guard_msgs (info) in
@[grind] theorem pqBad : P x y → Q x z := sorry
example : P a b → Q a c := by
grind
end using_grind_mixed
namespace using_grind_rhs
opaque f : Nat → Nat
opaque g : Nat → Nat → Nat
/--
info: [grind.ematch.pattern] fq: [g #0 (f #0)]
-/
#guard_msgs (info) in
@[grind =_]
theorem fq : f x = g x (f x) := sorry
end using_grind_rhs
namespace using_grind_lhs_rhs
opaque f : Nat → Nat
opaque g : Nat → Nat → Nat
/--
info: [grind.ematch.pattern] fq: [f #0]
[grind.ematch.pattern] fq: [g #0 (g #0 #0)]
-/
#guard_msgs (info) in
@[grind _=_]
theorem fq : f x = g x (g x x) := sorry
end using_grind_lhs_rhs

View file

@ -1,6 +1,8 @@
opaque g : Nat → Nat
@[grind_eq] def f (a : Nat) :=
set_option trace.Meta.debug true
@[grind] def f (a : Nat) :=
match a with
| 0 => 10
| x+1 => g (f x)
@ -21,7 +23,7 @@ info: [grind.assert] f (y + 1) = a
example : f (y + 1) = a → a = g (f y):= by
grind
@[grind_eq] def app (xs ys : List α) :=
@[grind] def app (xs ys : List α) :=
match xs with
| [] => ys
| x::xs => x :: app xs ys
@ -43,7 +45,7 @@ example : app [1, 2] ys = xs → xs = 1::2::ys := by
opaque p : Nat → Nat → Prop
opaque q : Nat → Prop
@[grind_eq] theorem pq : p x x ↔ q x := by sorry
@[grind =] theorem pq : p x x ↔ q x := by sorry
/--
info: [grind.assert] p a a
@ -58,7 +60,7 @@ example : p a a → q a := by
opaque appV (xs : Vector α n) (ys : Vector α m) : Vector α (n + m) :=
Vector.append xs ys
@[grind_eq]
@[grind =]
theorem appV_assoc (a : Vector α n) (b : Vector α m) (c : Vector α n') :
HEq (appV a (appV b c)) (appV (appV a b) c) := sorry

View file

@ -20,7 +20,7 @@ grind_pattern List.mem_concat_self => a ∈ xs ++ [a]
def foo (x : Nat) := x + x
/--
error: `foo` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic
error: `foo` is not a theorem
-/
#guard_msgs in
grind_pattern foo => x + x