feat: [grind?] attribute (#8426)

This PR adds the attribute `[grind?]`. It is like `[grind]` but displays
inferred E-matching patterns. It is a more convinient than writing.
Thanks @kim-em for suggesting this feature.
```lean
set_option trace.grind.ematch.pattern true
```
This PR also improves some tests, and adds helper function
`ENode.isRoot`.
This commit is contained in:
Leonardo de Moura 2025-05-20 20:32:49 -04:00 committed by GitHub
parent a541b8e75e
commit c28b052576
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 102 additions and 48 deletions

View file

@ -30,6 +30,7 @@ syntax grindIntro := &"intro "
syntax grindExt := &"ext "
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt
syntax (name := grind) "grind" (grindMod)? : attr
syntax (name := grind?) "grind?" (grindMod)? : attr
end Attr
end Lean.Parser

View file

@ -95,7 +95,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do
-- Assign on expressions associated with cutsat terms or interpreted terms
for e in goal.exprs do
let node ← goal.getENode e
if isSameExpr node.root node.self then
if node.isRoot then
if (← isIntNatENode node) then
if let some v ← getAssignment? goal node.self then
if v.den == 1 then used := used.insert v.num
@ -111,7 +111,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do
-- Assign the remaining ones with values not used by cutsat
for e in goal.exprs do
let node ← goal.getENode e
if isSameExpr node.root node.self then
if node.isRoot then
if (← isIntNatENode node) then
if model[node.self]?.isNone then
let v := pickUnusedValue goal model node.self nextVal used

View file

@ -49,11 +49,20 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
def throwInvalidUsrModifier : CoreM α :=
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
builtin_initialize
/--
Auxiliary function for registering `grind` and `grind?` attributes.
The `grind?` is an alias for `grind` which displays patterns using `logInfo`.
It is just a convenience for users.
-/
private def registerGrindAttr (showInfo : Bool) : IO Unit :=
registerBuiltinAttribute {
name := `grind
name := if showInfo then `grind? else `grind
descr :=
"The `[grind]` attribute is used to annotate declarations.\
let header := if showInfo then
"The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information."
else
"The `[grind]` attribute is used to annotate declarations."
header ++ "\
\
When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\
will mark the theorem for use in heuristic instantiations by the `grind` tactic,
@ -73,12 +82,12 @@ builtin_initialize
add := fun declName stx attrKind => MetaM.run' do
match (← getAttrKindFromOpt stx) with
| .ematch .user => throwInvalidUsrModifier
| .ematch k => addEMatchAttr declName attrKind k
| .ematch k => addEMatchAttr declName attrKind k (showInfo := showInfo)
| .cases eager => addCasesAttr declName eager attrKind
| .intro =>
if let some info ← isCasesAttrPredicateCandidate? declName false then
for ctor in info.ctors do
addEMatchAttr ctor attrKind .default
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
else
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
| .ext => addExtAttr declName attrKind
@ -89,10 +98,12 @@ builtin_initialize
-- If it is an inductive predicate,
-- we also add the constructors (intro rules) as E-matching rules
for ctor in info.ctors do
addEMatchAttr ctor attrKind .default
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
else
addEMatchAttr declName attrKind .default
addEMatchAttr declName attrKind .default (showInfo := showInfo)
erase := fun declName => MetaM.run' do
if showInfo then
throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead"
if (← isCasesAttrCandidate declName false) then
eraseCasesAttr declName
else if (← isExtTheorem declName) then
@ -101,4 +112,8 @@ builtin_initialize
eraseEMatchAttr declName
}
builtin_initialize
registerGrindAttr true
registerGrindAttr false
end Lean.Meta.Grind

View file

@ -589,18 +589,24 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
msg := msg ++ m!"{x} : {← inferType x}"
addMessageContextFull msg
private def logPatternWhen (showInfo : Bool) (origin : Origin) (patterns : List Expr) : MetaM Unit := do
if showInfo then
logInfo m!"{← origin.pp}: {patterns.map ppPattern}"
/--
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : EMatchTheoremKind) : MetaM EMatchTheorem := do
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr)
(patterns : List Expr) (kind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
if symbols.isEmpty then
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
logPatternWhen showInfo origin patterns
return {
proof, patterns, numParams, symbols
levelParams, origin, kind
@ -627,7 +633,7 @@ Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern.
-/
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) : MetaM EMatchTheorem := do
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) (showInfo := false) : MetaM EMatchTheorem := do
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
let (lhs, rhs) ← match_expr type with
| Eq _ lhs rhs => pure (lhs, rhs)
@ -640,15 +646,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat normConfig}"
let pats := splitWhileForbidden (pat.abstract xs)
return (xs.size, pats)
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs)
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs) (showInfo := showInfo)
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (showInfo := false) : MetaM EMatchTheorem := do
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
let_expr f@Eq α lhs rhs := type
| throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}"
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
return (xs.size, [pat.abstract xs])
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd (showInfo := showInfo)
/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
@ -657,8 +663,8 @@ creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the
pattern.
-/
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) : MetaM EMatchTheorem := do
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) (showInfo := false) : MetaM EMatchTheorem := do
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs (showInfo := showInfo)
/--
Adds an E-matching theorem to the environment.
@ -844,13 +850,13 @@ since the theorem is already in the `grind` state and there is nothing to be ins
-/
def mkEMatchTheoremWithKind?
(origin : Origin) (levelParams : Array Name) (proof : Expr) (kind : EMatchTheoremKind)
(groundPatterns := true) : MetaM (Option EMatchTheorem) := do
(groundPatterns := true) (showInfo := false) : MetaM (Option EMatchTheorem) := do
if kind == .eqLhs then
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true))
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true) (showInfo := showInfo))
else if kind == .eqRhs then
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false))
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false) (showInfo := showInfo))
else if kind == .eqBwd then
return (← mkEMatchEqBwdTheoremCore origin levelParams proof)
return (← mkEMatchEqBwdTheoremCore origin levelParams proof (showInfo := showInfo))
let type ← inferType proof
/-
Remark: we should not use `forallTelescopeReducing` (with default reducibility) here
@ -894,25 +900,26 @@ where
return none
let numParams := xs.size
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
logPatternWhen showInfo origin patterns
return some {
proof, patterns, numParams, symbols
levelParams, origin, kind
}
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) : MetaM EMatchTheorem := do
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind (showInfo := showInfo)
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
return thm
def mkEMatchEqTheoremsForDef? (declName : Name) : MetaM (Option (Array EMatchTheorem)) := do
def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Option (Array EMatchTheorem)) := do
let some eqns ← getEqnsFor? declName | return none
eqns.mapM fun eqn => do
mkEMatchEqTheorem eqn (normalizePattern := true)
mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo)
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) : MetaM Unit := do
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
if wasOriginallyTheorem (← getEnv) declName then
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs)) attrKind
else if let some thms ← mkEMatchEqTheoremsForDef? declName then
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (showInfo := showInfo)) attrKind
else if let some thms ← mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then
unless useLhs do
throwError "`{declName}` is a definition, you must only use the left-hand side for extracting patterns"
thms.forM (ematchTheoremsExt.add · attrKind)
@ -935,20 +942,20 @@ def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMat
throwErr
return s.erase <| .decl declName
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) : MetaM Unit := do
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM Unit := do
if thmKind == .eqLhs then
addGrindEqAttr declName attrKind thmKind (useLhs := true)
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
else if thmKind == .eqRhs then
addGrindEqAttr declName attrKind thmKind (useLhs := false)
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
else if thmKind == .eqBoth then
addGrindEqAttr declName attrKind thmKind (useLhs := true)
addGrindEqAttr declName attrKind thmKind (useLhs := false)
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
else
let info ← getConstInfo declName
if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then
addGrindEqAttr declName attrKind thmKind
addGrindEqAttr declName attrKind thmKind (showInfo := showInfo)
else
let thm ← mkEMatchTheoremForDecl declName thmKind
let thm ← mkEMatchTheoremForDecl declName thmKind (showInfo := showInfo)
ematchTheoremsExt.add thm attrKind
def eraseEMatchAttr (declName : Name) : MetaM Unit := do

View file

@ -123,7 +123,7 @@ def checkInvariants (expensive := false) : GoalM Unit := do
for e in (← getExprs) do
let node ← getENode e
checkParents node.self
if isSameExpr node.self node.root then
if node.isRoot then
checkEqc node
if expensive then
checkPtrEqImpliesStructEq

View file

@ -343,6 +343,9 @@ structure ENode where
-- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3.
deriving Inhabited, Repr
def ENode.isRoot (n : ENode) :=
isSameExpr n.self n.root
def ENode.isCongrRoot (n : ENode) :=
isSameExpr n.self n.congr
@ -1250,7 +1253,7 @@ def filterENodes (p : ENode → GoalM Bool) : GoalM (Array ENode) := do
def forEachEqcRoot (f : ENode → GoalM Unit) : GoalM Unit := do
for e in (← getExprs) do
let n ← getENode e
if isSameExpr n.self n.root then
if n.isRoot then
f n
abbrev Propagator := Expr → GoalM Unit
@ -1302,7 +1305,7 @@ partial def Goal.getEqcs (goal : Goal) : List (List Expr) := Id.run do
let mut r : List (List Expr) := []
for e in goal.exprs do
let some node := goal.getENode? e | pure ()
if isSameExpr node.root node.self then
if node.isRoot then
r := goal.getEqc node.self :: r
return r

View file

@ -44,3 +44,18 @@ set_option trace.grind.ematch.pattern true in
set_option trace.grind.ematch.pattern true in
@[grind =>] theorem State.update_le_update (h : State.le σ' σ) : State.le (σ'.update x v) (σ.update x v) :=
sorry
namespace Foo
/-- info: Rtrans: [R #4 #3, R #3 #2] -/
#guard_msgs (info) in
@[grind? ->]
axiom Rtrans {x y z : Nat} : R x y → R y z → R x z
/-- info: Rtrans': [R #4 #3, R #3 #2] -/
#guard_msgs (info) in
@[grind? →]
axiom Rtrans' {x y z : Nat} : R x y → R y z → R x z
end Foo

View file

@ -5,15 +5,19 @@ attribute [grind] List.countP_nil List.countP_cons
theorem List.countP_le_countP (hpq : ∀ x ∈ l, P x → Q x) :
l.countP P ≤ l.countP Q := by
induction l with
| nil => grind
| cons x xs ih =>
grind
induction l <;> grind
-- TODO: how to explain to the user that `l.countP P ≤ l.countP Q` is a bad pattern
grind_pattern List.countP_le_countP => l.countP P, l.countP Q
theorem List.countP_lt_countP (hpq : ∀ x ∈ l, P x → Q x) (y:α) (hx: y ∈ l) (hxP : P y = false) (hxQ : Q y) :
l.countP P < l.countP Q := by
induction l with
| nil => grind
| cons x xs ih =>
have : xs.countP P ≤ xs.countP Q := countP_le_countP (by grind)
grind
induction l <;> grind
/--
info: List.countP_nil: [@List.countP #1 #0 (@List.nil _)]
---
info: List.countP_cons: [@List.countP #3 #2 (@List.cons _ #1 #0)]
-/
#guard_msgs (info) in
attribute [grind?] List.countP_nil List.countP_cons

View file

@ -76,3 +76,12 @@ trace: [grind.assert] x1 = appV a_2 b
#guard_msgs (trace) in
example : x1 = appV a b → x2 = appV x1 c → x3 = appV b c → x4 = appV a x3 → HEq x2 x4 := by
grind
/--
info: appV_assoc': [@appV #6 #5 (@HAdd.hAdd `[Nat] `[Nat] `[Nat] `[instHAdd] #4 #3) #2 (@appV _ #4 #3 #1 #0)]
-/
#guard_msgs (info) in
@[grind? =]
theorem appV_assoc' (a : Vector α n) (b : Vector α m) (c : Vector α n') :
HEq (appV a (appV b c)) (appV (appV a b) c) := sorry

View file

@ -8,7 +8,7 @@ theorem length_pos_of_ne_nil {l : List α} (h : l ≠ []) : 0 < l.length := by
theorem getLast?_dropLast {xs : List α} :
xs.dropLast.getLast? = if xs.length ≤ 1 then none else xs[xs.length - 2]? := by
grind (splits := 9) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?,
grind (splits := 15) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?,
List.head?_eq_getLast?_reverse, getElem?_dropLast, List.getLast?_reverse, List.length_dropLast,
List.length_reverse, length_nil, List.reverse_reverse, head?_nil, List.getElem?_eq_none,
length_pos_of_ne_nil, getLast?_nil, List.head?_reverse, List.getLast?_eq_head?_reverse,