From c28b0525763b6afabee5e243f37b77a16c0be412 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 20 May 2025 20:32:49 -0400 Subject: [PATCH] feat: `[grind?]` attribute (#8426) This PR adds the attribute `[grind?]`. It is like `[grind]` but displays inferred E-matching patterns. It is a more convinient than writing. Thanks @kim-em for suggesting this feature. ```lean set_option trace.grind.ematch.pattern true ``` This PR also improves some tests, and adds helper function `ENode.isRoot`. --- src/Init/Grind/Tactics.lean | 1 + .../Meta/Tactic/Grind/Arith/Cutsat/Model.lean | 4 +- src/Lean/Meta/Tactic/Grind/Attr.lean | 29 ++++++--- src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean | 59 +++++++++++-------- src/Lean/Meta/Tactic/Grind/Inv.lean | 2 +- src/Lean/Meta/Tactic/Grind/Types.lean | 7 ++- tests/lean/run/grind_attrs.lean | 15 +++++ tests/lean/run/grind_countP.lean | 22 ++++--- tests/lean/run/grind_eq.lean | 9 +++ ...t_dropLast => grind_getLast_dropLast.lean} | 2 +- 10 files changed, 102 insertions(+), 48 deletions(-) rename tests/lean/run/{grind_getLast_dropLast => grind_getLast_dropLast.lean} (87%) diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index 9b455bbffb..765e234fda 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -30,6 +30,7 @@ syntax grindIntro := &"intro " syntax grindExt := &"ext " syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt syntax (name := grind) "grind" (grindMod)? : attr +syntax (name := grind?) "grind?" (grindMod)? : attr end Attr end Lean.Parser diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean index efc7269594..d319b6685f 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean @@ -95,7 +95,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do -- Assign on expressions associated with cutsat terms or interpreted terms for e in goal.exprs do let node ← goal.getENode e - if isSameExpr node.root node.self then + if node.isRoot then if (← isIntNatENode node) then if let some v ← getAssignment? goal node.self then if v.den == 1 then used := used.insert v.num @@ -111,7 +111,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do -- Assign the remaining ones with values not used by cutsat for e in goal.exprs do let node ← goal.getENode e - if isSameExpr node.root node.self then + if node.isRoot then if (← isIntNatENode node) then if model[node.self]?.isNone then let v := pickUnusedValue goal model node.self nextVal used diff --git a/src/Lean/Meta/Tactic/Grind/Attr.lean b/src/Lean/Meta/Tactic/Grind/Attr.lean index 2388c40950..411aa7497b 100644 --- a/src/Lean/Meta/Tactic/Grind/Attr.lean +++ b/src/Lean/Meta/Tactic/Grind/Attr.lean @@ -49,11 +49,20 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do def throwInvalidUsrModifier : CoreM α := throwError "the modifier `usr` is only relevant in parameters for `grind only`" -builtin_initialize +/-- +Auxiliary function for registering `grind` and `grind?` attributes. +The `grind?` is an alias for `grind` which displays patterns using `logInfo`. +It is just a convenience for users. +-/ +private def registerGrindAttr (showInfo : Bool) : IO Unit := registerBuiltinAttribute { - name := `grind + name := if showInfo then `grind? else `grind descr := - "The `[grind]` attribute is used to annotate declarations.\ + let header := if showInfo then + "The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information." + else + "The `[grind]` attribute is used to annotate declarations." + header ++ "\ \ When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\ will mark the theorem for use in heuristic instantiations by the `grind` tactic, @@ -73,12 +82,12 @@ builtin_initialize add := fun declName stx attrKind => MetaM.run' do match (← getAttrKindFromOpt stx) with | .ematch .user => throwInvalidUsrModifier - | .ematch k => addEMatchAttr declName attrKind k + | .ematch k => addEMatchAttr declName attrKind k (showInfo := showInfo) | .cases eager => addCasesAttr declName eager attrKind | .intro => if let some info ← isCasesAttrPredicateCandidate? declName false then for ctor in info.ctors do - addEMatchAttr ctor attrKind .default + addEMatchAttr ctor attrKind .default (showInfo := showInfo) else throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate" | .ext => addExtAttr declName attrKind @@ -89,10 +98,12 @@ builtin_initialize -- If it is an inductive predicate, -- we also add the constructors (intro rules) as E-matching rules for ctor in info.ctors do - addEMatchAttr ctor attrKind .default + addEMatchAttr ctor attrKind .default (showInfo := showInfo) else - addEMatchAttr declName attrKind .default + addEMatchAttr declName attrKind .default (showInfo := showInfo) erase := fun declName => MetaM.run' do + if showInfo then + throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead" if (← isCasesAttrCandidate declName false) then eraseCasesAttr declName else if (← isExtTheorem declName) then @@ -101,4 +112,8 @@ builtin_initialize eraseEMatchAttr declName } +builtin_initialize + registerGrindAttr true + registerGrindAttr false + end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 8bc2179594..a2051d74f6 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -589,18 +589,24 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) : msg := msg ++ m!"{x} : {← inferType x}" addMessageContextFull msg +private def logPatternWhen (showInfo : Bool) (origin : Origin) (patterns : List Expr) : MetaM Unit := do + if showInfo then + logInfo m!"{← origin.pp}: {patterns.map ppPattern}" + /-- Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns. Pattern variables are represented using de Bruijn indices. -/ -def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : EMatchTheoremKind) : MetaM EMatchTheorem := do +def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) + (patterns : List Expr) (kind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns if symbols.isEmpty then throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing" - trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}" + trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}" if let .missing pos ← checkCoverage proof numParams bvarFound then let pats : MessageData := m!"{patterns.map ppPattern}" throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}" + logPatternWhen showInfo origin patterns return { proof, patterns, numParams, symbols levelParams, origin, kind @@ -627,7 +633,7 @@ Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs creates an E-matching pattern for it using `addEMatchTheorem n [lhs]` If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern. -/ -def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) : MetaM EMatchTheorem := do +def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) (showInfo := false) : MetaM EMatchTheorem := do let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do let (lhs, rhs) ← match_expr type with | Eq _ lhs rhs => pure (lhs, rhs) @@ -640,15 +646,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat normConfig}" let pats := splitWhileForbidden (pat.abstract xs) return (xs.size, pats) - mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs) + mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs) (showInfo := showInfo) -def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do +def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (showInfo := false) : MetaM EMatchTheorem := do let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do let_expr f@Eq α lhs rhs := type | throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}" let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs) return (xs.size, [pat.abstract xs]) - mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd + mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd (showInfo := showInfo) /-- Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`, @@ -657,8 +663,8 @@ creates an E-matching pattern for it using `addEMatchTheorem n [lhs]` If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern. -/ -def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) : MetaM EMatchTheorem := do - mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs +def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) (showInfo := false) : MetaM EMatchTheorem := do + mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs (showInfo := showInfo) /-- Adds an E-matching theorem to the environment. @@ -844,13 +850,13 @@ since the theorem is already in the `grind` state and there is nothing to be ins -/ def mkEMatchTheoremWithKind? (origin : Origin) (levelParams : Array Name) (proof : Expr) (kind : EMatchTheoremKind) - (groundPatterns := true) : MetaM (Option EMatchTheorem) := do + (groundPatterns := true) (showInfo := false) : MetaM (Option EMatchTheorem) := do if kind == .eqLhs then - return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true)) + return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true) (showInfo := showInfo)) else if kind == .eqRhs then - return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false)) + return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false) (showInfo := showInfo)) else if kind == .eqBwd then - return (← mkEMatchEqBwdTheoremCore origin levelParams proof) + return (← mkEMatchEqBwdTheoremCore origin levelParams proof (showInfo := showInfo)) let type ← inferType proof /- Remark: we should not use `forallTelescopeReducing` (with default reducibility) here @@ -894,25 +900,26 @@ where return none let numParams := xs.size trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}" + logPatternWhen showInfo origin patterns return some { proof, patterns, numParams, symbols levelParams, origin, kind } -def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) : MetaM EMatchTheorem := do - let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind +def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do + let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind (showInfo := showInfo) | throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command" return thm -def mkEMatchEqTheoremsForDef? (declName : Name) : MetaM (Option (Array EMatchTheorem)) := do +def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Option (Array EMatchTheorem)) := do let some eqns ← getEqnsFor? declName | return none eqns.mapM fun eqn => do - mkEMatchEqTheorem eqn (normalizePattern := true) + mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo) -private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) : MetaM Unit := do +private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do if wasOriginallyTheorem (← getEnv) declName then - ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs)) attrKind - else if let some thms ← mkEMatchEqTheoremsForDef? declName then + ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (showInfo := showInfo)) attrKind + else if let some thms ← mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then unless useLhs do throwError "`{declName}` is a definition, you must only use the left-hand side for extracting patterns" thms.forM (ematchTheoremsExt.add · attrKind) @@ -935,20 +942,20 @@ def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMat throwErr return s.erase <| .decl declName -def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) : MetaM Unit := do +def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM Unit := do if thmKind == .eqLhs then - addGrindEqAttr declName attrKind thmKind (useLhs := true) + addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo) else if thmKind == .eqRhs then - addGrindEqAttr declName attrKind thmKind (useLhs := false) + addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo) else if thmKind == .eqBoth then - addGrindEqAttr declName attrKind thmKind (useLhs := true) - addGrindEqAttr declName attrKind thmKind (useLhs := false) + addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo) + addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo) else let info ← getConstInfo declName if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then - addGrindEqAttr declName attrKind thmKind + addGrindEqAttr declName attrKind thmKind (showInfo := showInfo) else - let thm ← mkEMatchTheoremForDecl declName thmKind + let thm ← mkEMatchTheoremForDecl declName thmKind (showInfo := showInfo) ematchTheoremsExt.add thm attrKind def eraseEMatchAttr (declName : Name) : MetaM Unit := do diff --git a/src/Lean/Meta/Tactic/Grind/Inv.lean b/src/Lean/Meta/Tactic/Grind/Inv.lean index 0e8558bb3c..66fc809758 100644 --- a/src/Lean/Meta/Tactic/Grind/Inv.lean +++ b/src/Lean/Meta/Tactic/Grind/Inv.lean @@ -123,7 +123,7 @@ def checkInvariants (expensive := false) : GoalM Unit := do for e in (← getExprs) do let node ← getENode e checkParents node.self - if isSameExpr node.self node.root then + if node.isRoot then checkEqc node if expensive then checkPtrEqImpliesStructEq diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 0491db1abc..a676e7d969 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -343,6 +343,9 @@ structure ENode where -- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3. deriving Inhabited, Repr +def ENode.isRoot (n : ENode) := + isSameExpr n.self n.root + def ENode.isCongrRoot (n : ENode) := isSameExpr n.self n.congr @@ -1250,7 +1253,7 @@ def filterENodes (p : ENode → GoalM Bool) : GoalM (Array ENode) := do def forEachEqcRoot (f : ENode → GoalM Unit) : GoalM Unit := do for e in (← getExprs) do let n ← getENode e - if isSameExpr n.self n.root then + if n.isRoot then f n abbrev Propagator := Expr → GoalM Unit @@ -1302,7 +1305,7 @@ partial def Goal.getEqcs (goal : Goal) : List (List Expr) := Id.run do let mut r : List (List Expr) := [] for e in goal.exprs do let some node := goal.getENode? e | pure () - if isSameExpr node.root node.self then + if node.isRoot then r := goal.getEqc node.self :: r return r diff --git a/tests/lean/run/grind_attrs.lean b/tests/lean/run/grind_attrs.lean index 7c97e139fd..f12fced84d 100644 --- a/tests/lean/run/grind_attrs.lean +++ b/tests/lean/run/grind_attrs.lean @@ -44,3 +44,18 @@ set_option trace.grind.ematch.pattern true in set_option trace.grind.ematch.pattern true in @[grind =>] theorem State.update_le_update (h : State.le σ' σ) : State.le (σ'.update x v) (σ.update x v) := sorry + + +namespace Foo + +/-- info: Rtrans: [R #4 #3, R #3 #2] -/ +#guard_msgs (info) in +@[grind? ->] +axiom Rtrans {x y z : Nat} : R x y → R y z → R x z + +/-- info: Rtrans': [R #4 #3, R #3 #2] -/ +#guard_msgs (info) in +@[grind? →] +axiom Rtrans' {x y z : Nat} : R x y → R y z → R x z + +end Foo diff --git a/tests/lean/run/grind_countP.lean b/tests/lean/run/grind_countP.lean index 596f9a21b2..f329869d2b 100644 --- a/tests/lean/run/grind_countP.lean +++ b/tests/lean/run/grind_countP.lean @@ -5,15 +5,19 @@ attribute [grind] List.countP_nil List.countP_cons theorem List.countP_le_countP (hpq : ∀ x ∈ l, P x → Q x) : l.countP P ≤ l.countP Q := by - induction l with - | nil => grind - | cons x xs ih => - grind + induction l <;> grind + +-- TODO: how to explain to the user that `l.countP P ≤ l.countP Q` is a bad pattern +grind_pattern List.countP_le_countP => l.countP P, l.countP Q theorem List.countP_lt_countP (hpq : ∀ x ∈ l, P x → Q x) (y:α) (hx: y ∈ l) (hxP : P y = false) (hxQ : Q y) : l.countP P < l.countP Q := by - induction l with - | nil => grind - | cons x xs ih => - have : xs.countP P ≤ xs.countP Q := countP_le_countP (by grind) - grind + induction l <;> grind + +/-- +info: List.countP_nil: [@List.countP #1 #0 (@List.nil _)] +--- +info: List.countP_cons: [@List.countP #3 #2 (@List.cons _ #1 #0)] +-/ +#guard_msgs (info) in +attribute [grind?] List.countP_nil List.countP_cons diff --git a/tests/lean/run/grind_eq.lean b/tests/lean/run/grind_eq.lean index c3163306c4..c71f7e65bd 100644 --- a/tests/lean/run/grind_eq.lean +++ b/tests/lean/run/grind_eq.lean @@ -76,3 +76,12 @@ trace: [grind.assert] x1 = appV a_2 b #guard_msgs (trace) in example : x1 = appV a b → x2 = appV x1 c → x3 = appV b c → x4 = appV a x3 → HEq x2 x4 := by grind + + +/-- +info: appV_assoc': [@appV #6 #5 (@HAdd.hAdd `[Nat] `[Nat] `[Nat] `[instHAdd] #4 #3) #2 (@appV _ #4 #3 #1 #0)] +-/ +#guard_msgs (info) in +@[grind? =] +theorem appV_assoc' (a : Vector α n) (b : Vector α m) (c : Vector α n') : + HEq (appV a (appV b c)) (appV (appV a b) c) := sorry diff --git a/tests/lean/run/grind_getLast_dropLast b/tests/lean/run/grind_getLast_dropLast.lean similarity index 87% rename from tests/lean/run/grind_getLast_dropLast rename to tests/lean/run/grind_getLast_dropLast.lean index e2bb5350d2..08773ab765 100644 --- a/tests/lean/run/grind_getLast_dropLast +++ b/tests/lean/run/grind_getLast_dropLast.lean @@ -8,7 +8,7 @@ theorem length_pos_of_ne_nil {l : List α} (h : l ≠ []) : 0 < l.length := by theorem getLast?_dropLast {xs : List α} : xs.dropLast.getLast? = if xs.length ≤ 1 then none else xs[xs.length - 2]? := by - grind (splits := 9) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?, + grind (splits := 15) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?, List.head?_eq_getLast?_reverse, getElem?_dropLast, List.getLast?_reverse, List.length_dropLast, List.length_reverse, length_nil, List.reverse_reverse, head?_nil, List.getElem?_eq_none, length_pos_of_ne_nil, getLast?_nil, List.head?_reverse, List.getLast?_eq_head?_reverse,