From 0d2a574f96fe750660aded1bef28b8424671653e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 21 Dec 2025 18:57:25 -0800 Subject: [PATCH] feat: user-defined `grind` attributes (#11765) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements user-defined `grind` attributes. They are useful for users that want to implement tactics using the `grind` infrastructure (e.g., `progress*` in Aeneas). New `grind` attributes are declared using the command ```lean register_grind_attr my_grind ``` The command is similar to `register_simp_attr`. After the new attribute is declared. Recall that similar to `register_simp_attr`, the new attribute cannot be used in the same file it is declared. ```lean opaque f : Nat → Nat opaque g : Nat → Nat @[my_grind] theorem fax : f (f x) = f x := sorry example theorem fax2 : f (f (f x)) = f x := by fail_if_success grind grind [my_grind] ``` TODO: remove leftovers after update stage0 --- src/Init/Grind/Attr.lean | 1 + src/Lean/Elab/Tactic/Grind/Lint.lean | 7 +- src/Lean/Elab/Tactic/Grind/Main.lean | 13 +- src/Lean/Elab/Tactic/Grind/Param.lean | 74 ++++-- src/Lean/Meta/Tactic/Grind.lean | 1 + src/Lean/Meta/Tactic/Grind/Attr.lean | 217 ++++++++++++----- src/Lean/Meta/Tactic/Grind/Cases.lean | 21 +- src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean | 202 ++++++---------- src/Lean/Meta/Tactic/Grind/ExtAttr.lean | 14 +- src/Lean/Meta/Tactic/Grind/Extension.lean | 221 ++++++++++++++++++ src/Lean/Meta/Tactic/Grind/FunCC.lean | 3 + src/Lean/Meta/Tactic/Grind/Injective.lean | 26 +-- src/Lean/Meta/Tactic/Grind/Internalize.lean | 10 +- src/Lean/Meta/Tactic/Grind/Intro.lean | 8 +- src/Lean/Meta/Tactic/Grind/Main.lean | 51 +++- .../Meta/Tactic/Grind/PropagatorAttr.lean | 3 + .../Meta/Tactic/Grind/RegisterCommand.lean | 30 +++ src/Lean/Meta/Tactic/Grind/SimpUtil.lean | 11 +- src/Lean/Meta/Tactic/Grind/Theorems.lean | 32 +++ src/Lean/Meta/Tactic/Grind/Types.lean | 22 +- src/Lean/Meta/Tactic/LibrarySearch.lean | 2 +- src/Lean/Meta/Tactic/Try/Collect.lean | 2 +- stage0/src/stdlib_flags.h | 1 + tests/lean/run/sharecommon_mpz.lean | 2 +- tests/pkg/user_attr/UserAttr/BlaAttr.lean | 2 + tests/pkg/user_attr/UserAttr/Tst.lean | 19 ++ 26 files changed, 720 insertions(+), 275 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/Extension.lean create mode 100644 src/Lean/Meta/Tactic/Grind/RegisterCommand.lean diff --git a/src/Init/Grind/Attr.lean b/src/Init/Grind/Attr.lean index 0f2c0053c4..621785e629 100644 --- a/src/Init/Grind/Attr.lean +++ b/src/Init/Grind/Attr.lean @@ -255,6 +255,7 @@ theorem fg_eq (h : x > 0) : f (g x) = x -- With minimal subexpression: @[grind! <-] theorem fg_eq (h : x > 0) : f (g x) = x -- Pattern selected: `g x` +``` -/ syntax (name := grind!) "grind!" (ppSpace grindMod)? : attr /-- diff --git a/src/Lean/Elab/Tactic/Grind/Lint.lean b/src/Lean/Elab/Tactic/Grind/Lint.lean index d1d361d5d0..fe77210f35 100644 --- a/src/Lean/Elab/Tactic/Grind/Lint.lean +++ b/src/Lean/Elab/Tactic/Grind/Lint.lean @@ -88,15 +88,14 @@ def mkConfig (items : Array (TSyntax `Lean.Parser.Tactic.configItem)) : TermElab elabConfigItems defaultConfig items def mkParams (config : Grind.Config) : MetaM Params := do - let params ← Meta.Grind.mkParams config - let casesTypes ← Grind.getCasesTypes - let mut ematch ← getEMatchTheorems + let params ← Meta.Grind.mkDefaultParams config + let mut ematch := params.extensions[0]!.ematch for declName in muteExt.getState (← getEnv) do try ematch ← ematch.eraseDecl declName catch _ => pure () -- Ignore failures here. - return { params with ematch, casesTypes } + return { params with extensions[0].ematch := ematch } /-- Returns the total number of generated instances. -/ def sum (cs : PHashMap Grind.Origin Nat) : Nat := Id.run do diff --git a/src/Lean/Elab/Tactic/Grind/Main.lean b/src/Lean/Elab/Tactic/Grind/Main.lean index 6b2846433d..21284137b7 100644 --- a/src/Lean/Elab/Tactic/Grind/Main.lean +++ b/src/Lean/Elab/Tactic/Grind/Main.lean @@ -210,22 +210,13 @@ def elabGrindParamsAndSuggestions def mkGrindParams (config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Parser.Tactic.grindParam) (mvarId : MVarId) : TermElabM Grind.Params := do - let params ← Grind.mkParams config - let ematch ← if only then pure default else Grind.getEMatchTheorems - let inj ← if only then pure default else Grind.getInjectiveTheorems - /- - **Note**: We used to skip the global cases attribute when `only = true`, but - this is not very effective. We now use anchors to restrict the set of case-splits. - -/ - let casesTypes ← Grind.getCasesTypes - let funCCs ← Grind.getFunCCSet - let params := { params with ematch, casesTypes, inj, funCCs } + let params ← if only then Grind.mkOnlyParams config else Grind.mkDefaultParams config let suggestions ← if config.suggestions then LibrarySuggestions.select mvarId { caller := some "grind" } else pure #[] let mut params ← elabGrindParamsAndSuggestions params ps suggestions (only := only) (lax := config.lax) - trace[grind.debug.inj] "{params.inj.getOrigins.map (·.pp)}" + trace[grind.debug.inj] "{params.extensions[0]!.inj.getOrigins.map (·.pp)}" if params.anchorRefs?.isSome then /- **Note**: anchors are automatically computed in interactive mode where diff --git a/src/Lean/Elab/Tactic/Grind/Param.lean b/src/Lean/Elab/Tactic/Grind/Param.lean index 45501a39d2..4290f06c37 100644 --- a/src/Lean/Elab/Tactic/Grind/Param.lean +++ b/src/Lean/Elab/Tactic/Grind/Param.lean @@ -20,7 +20,49 @@ open Meta `grind` parameter elaboration -/ -def warnRedundantEMatchArg (s : Grind.EMatchTheorems) (declName : Name) : MetaM Unit := do +def _root_.Lean.Meta.Grind.Params.insertCasesTypes (params : Grind.Params) (declName : Name) (eager : Bool) : Grind.Params := + { params with extensions := params.extensions.modify 0 fun ext => { ext with casesTypes := ext.casesTypes.insert declName eager } } + +def _root_.Lean.Meta.Grind.Params.eraseCasesTypes (params : Grind.Params) (declName : Name) : CoreM Grind.Params := do + unless params.extensions.any fun ext => ext.casesTypes.contains declName do + Grind.throwNotMarkedWithGrindAttribute declName + return { params with extensions := params.extensions.modify 0 fun ext => { ext with casesTypes := ext.casesTypes.erase declName } } + +def _root_.Lean.Meta.Grind.Params.insertFunCC (params : Grind.Params) (declName : Name) : Grind.Params := + { params with extensions := params.extensions.modify 0 fun ext => { ext with funCC := ext.funCC.insert declName } } + +def _root_.Lean.Meta.Grind.Params.containsEMatch (params : Grind.Params) (declName : Name) : Bool := + params.extensions.any fun ext => ext.ematch.contains (.decl declName) + +def _root_.Lean.Meta.Grind.Params.eraseEMatchCore (params : Grind.Params) (declName : Name) : Grind.Params := + { params with extensions := params.extensions.modify 0 fun ext => { ext with ematch := ext.ematch.erase (.decl declName) } } + +def _root_.Lean.Meta.Grind.Params.eraseEMatch (params : Grind.Params) (declName : Name) : MetaM Grind.Params := do + if !wasOriginallyTheorem (← getEnv) declName then + if let some eqns ← getEqnsFor? declName then + unless eqns.all fun eqn => params.containsEMatch eqn do + Grind.throwNotMarkedWithGrindAttribute declName + return eqns.foldl (init := params) fun params eqn => params.eraseEMatchCore eqn + else + Grind.throwNotMarkedWithGrindAttribute declName + else + unless params.containsEMatch declName do + Grind.throwNotMarkedWithGrindAttribute declName + return params.eraseEMatchCore declName + +def _root_.Lean.Meta.Grind.Params.eraseInj (params : Grind.Params) (declName : Name) : Grind.Params := + { params with extensions := params.extensions.modify 0 fun ext => { ext with inj := ext.inj.erase (.decl declName) } } + +def _root_.Lean.Meta.Grind.ExtensionStateArray.getKindsFor (s : Grind.ExtensionStateArray) (origin : Grind.Origin) : List Grind.EMatchTheoremKind := Id.run do + let mut result := [] + for ext in s do + let s : Grind.EMatchTheorems := ext.ematch + let ks := s.getKindsFor origin + unless ks.isEmpty do + result := result ++ ks + return result + +def warnRedundantEMatchArg (s : Grind.ExtensionStateArray) (declName : Name) : MetaM Unit := do let minIndexable := false -- TODO: infer it let kinds ← match s.getKindsFor (.decl declName) with | [] => return () @@ -54,9 +96,9 @@ public def addEMatchTheorem (params : Grind.Params) (id : Ident) (declName : Nam let thm₁ ← Grind.mkEMatchTheoremForDecl declName (.eqLhs gen) params.symPrios let thm₂ ← Grind.mkEMatchTheoremForDecl declName (.eqRhs gen) params.symPrios if warn && - params.ematch.containsWithSamePatterns thm₁.origin thm₁.patterns thm₁.cnstrs && - params.ematch.containsWithSamePatterns thm₂.origin thm₂.patterns thm₂.cnstrs then - warnRedundantEMatchArg params.ematch declName + params.extensions.containsWithSamePatterns thm₁.origin thm₁.patterns thm₁.cnstrs && + params.extensions.containsWithSamePatterns thm₂.origin thm₂.patterns thm₂.cnstrs then + warnRedundantEMatchArg params.extensions declName return { params with extra := params.extra.push thm₁ |>.push thm₂ } | _ => if kind matches .eqLhs _ | .eqRhs _ then @@ -65,8 +107,8 @@ public def addEMatchTheorem (params : Grind.Params) (id : Ident) (declName : Nam Grind.mkEMatchTheoremAndSuggest id declName params.symPrios minIndexable (isParam := true) else Grind.mkEMatchTheoremForDecl declName kind params.symPrios (minIndexable := minIndexable) - if warn && params.ematch.containsWithSamePatterns thm.origin thm.patterns thm.cnstrs then - warnRedundantEMatchArg params.ematch declName + if warn && params.extensions.containsWithSamePatterns thm.origin thm.patterns thm.cnstrs then + warnRedundantEMatchArg params.extensions declName return { params with extra := params.extra.push thm } | .defn => if (← isReducible declName) then @@ -154,9 +196,13 @@ def processParam (params : Grind.Params) catch err => if (← resolveLocalName id.getId).isSome then throwErrorAt id "redundant parameter `{id}`, `grind` uses local hypotheses automatically" + else if let some ext ← Grind.getExtension? id.getId then + if let some mod := mod? then + throwErrorAt mod "invalid use of modifier in `grind` attribute `{id.getId}`" + return { params with extensions := params.extensions.push (ext.getState (← getEnv)) } else if !id.getId.getPrefix.isAnonymous then -- Fall back to term elaboration for compound identifiers like `foo.le` (dot notation on declarations) - return ← processTermParam params p mod? id minIndexable + return (← processTermParam params p mod? id minIndexable) else throw err Linter.checkDeprecated declName @@ -179,7 +225,7 @@ def processParam (params : Grind.Params) if incremental then throwError "`cases` parameter are not supported here" ensureNoMinIndexable minIndexable withRef p <| Grind.validateCasesAttr declName eager - params := { params with casesTypes := params.casesTypes.insert declName eager } + params := params.insertCasesTypes declName eager | .intro => if let some info ← Grind.isCasesAttrPredicateCandidate? declName false then if incremental then throwError "`cases` parameter are not supported here" @@ -194,7 +240,7 @@ def processParam (params : Grind.Params) throwError "`[grind ext]` cannot be set using parameters" | .infer => if let some declName ← Grind.isCasesAttrCandidate? declName false then - params := { params with casesTypes := params.casesTypes.insert declName false } + params := params.insertCasesTypes declName false if let some info ← isInductivePredicate? declName then -- If it is an inductive predicate, -- we also add the constructors (intro rules) as E-matching rules @@ -207,7 +253,7 @@ def processParam (params : Grind.Params) ensureNoMinIndexable minIndexable params := { params with symPrios := params.symPrios.insert declName prio } | .funCC => - params := { params with funCCs := params.funCCs.insert declName } + params := params.insertFunCC declName return params /-- @@ -228,11 +274,11 @@ public def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.T Linter.checkDeprecated declName if let some declName ← Grind.isCasesAttrCandidate? declName false then Grind.ensureNotBuiltinCases declName - params := { params with casesTypes := (← params.casesTypes.eraseDecl declName) } + params ← params.eraseCasesTypes declName else if (← Grind.isInjectiveTheorem declName) then - params := { params with inj := params.inj.erase (.decl declName) } + params := params.eraseInj declName else - params := { params with ematch := (← params.ematch.eraseDecl declName) } + params ← params.eraseEMatch declName | `(Parser.Tactic.grindParam| $[$mod?:grindMod]? $id:ident) => -- Check if this is dot notation on a local variable (e.g., `n.triv` for `Nat.triv n`). -- If so, process as term to let elaboration resolve the dot notation properly. @@ -294,7 +340,7 @@ public def withParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic let mut params := params if only then params := { params with - ematch := {} + extensions := params.extensions.map fun ext => { ext with ematch := {} } anchorRefs? := none } params ← elabGrindParams params ps (only := only) (incremental := true) diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index d389fa19e9..d9255e9950 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -48,6 +48,7 @@ public import Lean.Meta.Tactic.Grind.Filter public import Lean.Meta.Tactic.Grind.CollectParams public import Lean.Meta.Tactic.Grind.Finish public import Lean.Meta.Tactic.Grind.FunCC +public import Lean.Meta.Tactic.Grind.RegisterCommand public section namespace Lean diff --git a/src/Lean/Meta/Tactic/Grind/Attr.lean b/src/Lean/Meta/Tactic/Grind/Attr.lean index 768a48f995..98e9ee8c88 100644 --- a/src/Lean/Meta/Tactic/Grind/Attr.lean +++ b/src/Lean/Meta/Tactic/Grind/Attr.lean @@ -63,40 +63,89 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do def throwInvalidUsrModifier : CoreM α := throwError "the modifier `usr` is only relevant in parameters for `grind only`" +private def Extension.addCasesAttr (ext : Extension) (declName : Name) (eager : Bool) (attrKind : AttributeKind) : CoreM Unit := do + validateCasesAttr declName eager + ext.add (.cases declName eager) attrKind + +private def Extension.addExtAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do + validateExtAttr declName + ext.add (.ext declName) attrKind + +private def Extension.addFunCCAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do + ext.add (.funCC declName) attrKind + +private def Extension.eraseExtAttr (ext : Extension) (declName : Name) : CoreM Unit := do + let s := ext.getState (← getEnv) + let extThms ← s.extThms.eraseDecl declName + modifyEnv fun env => ext.modifyState env fun s => { s with extThms } + +private def Extension.eraseCasesAttr (ext : Extension) (declName : Name) : CoreM Unit := do + ensureNotBuiltinCases declName + let s := ext.getState (← getEnv) + let casesTypes ← s.casesTypes.eraseDecl declName + modifyEnv fun env => ext.modifyState env fun s => { s with casesTypes } + +private def Extension.eraseFunCCAttr (ext : Extension) (declName : Name) : CoreM Unit := do + let s := ext.getState (← getEnv) + unless s.funCC.contains declName do + throwNotMarkedWithGrindAttribute declName + let funCC := s.funCC.erase declName + modifyEnv fun env => ext.modifyState env fun s => { s with funCC } + +private def Extension.eraseEMatchAttr (ext : Extension) (declName : Name) : MetaM Unit := do + let s := ext.getState (← getEnv) + let ematch ← s.ematch.eraseDecl declName + modifyEnv fun env => ext.modifyState env fun s => { s with ematch } + +private def Extension.eraseInjectiveAttr (ext : Extension) (declName : Name) : MetaM Unit := do + let s := ext.getState (← getEnv) + let inj ← s.inj.eraseDecl declName + modifyEnv fun env => ext.modifyState env fun s => { s with inj } + +private def Extension.isExtTheorem (ext : Extension) (declName : Name) : CoreM Bool := do + return ext.getState (← getEnv) |>.extThms.contains declName + +private def Extension.isInjectiveTheorem (ext : Extension) (declName : Name) : CoreM Bool := do + return ext.getState (← getEnv) |>.inj.contains (.decl declName) + +private def Extension.hasFunCCAttr (ext : Extension) (declName : Name) : CoreM Bool := do + return ext.getState (← getEnv) |>.funCC.contains declName + /-- Auxiliary function for registering `grind`, `grind!`, `grind?`, and `grind!?` attributes. `grind!` is like `grind` but selects minimal indexable subterms. The `grind?` and `grind!?` are aliases for `grind` and `grind!` which displays patterns using `logInfo`. It is just a convenience for users. -/ -private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit := +private def mkGrindAttr (attrName : Name) (minIndexable : Bool) (showInfo : Bool) (ext? : Option Extension := none) (ref : Name := by exact decl_name%) : IO Unit := registerBuiltinAttribute { + ref := ref name := match minIndexable, showInfo with - | false, false => `grind - | false, true => `grind? - | true, false => `grind! - | true, true => `grind!? + | false, false => attrName + | false, true => attrName.appendAfter "?" + | true, false => attrName.appendAfter "!" + | true, true => attrName.appendAfter "!?" descr := let header := match minIndexable, showInfo with - | false, false => "The `[grind]` attribute is used to annotate declarations." - | false, true => "The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information." - | true, false => "The `[grind!]` attribute is used to annotate declarations, but selecting minimal indexable subterms." - | true, true => "The `[grind!?]` attribute is identical to the `[grind!]` attribute, but displays inferred pattern information." - header ++ "\ + | false, false => s!"The `[{attrName}]` attribute is used to annotate declarations." + | false, true => s!"The `[{attrName}?]` attribute is identical to the `[{attrName}]` attribute, but displays inferred pattern information." + | true, false => s!"The `[{attrName}!]` attribute is used to annotate declarations, but selecting minimal indexable subterms." + | true, true => s!"The `[{attrName}!?]` attribute is identical to the `[{attrName}!]` attribute, but displays inferred pattern information." + header ++ s!"\ \ - When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\ - will mark the theorem for use in heuristic instantiations by the `grind` tactic, + When applied to an equational theorem, `[{attrName} =]`, `[{attrName} =_]`, or `[{attrName} _=_]`\ + will mark the theorem for use in heuristic instantiations by the `{attrName}` tactic, using respectively the left-hand side, the right-hand side, or both sides of the theorem.\ - When applied to a function, `[grind =]` automatically annotates the equational theorems associated with that function.\ - When applied to a theorem `[grind ←]` will instantiate the theorem whenever it encounters the conclusion of the theorem + When applied to a function, `[{attrName} =]` automatically annotates the equational theorems associated with that function.\ + When applied to a theorem `[{attrName} ←]` will instantiate the theorem whenever it encounters the conclusion of the theorem (that is, it will use the theorem for backwards reasoning).\ - When applied to a theorem `[grind →]` will instantiate the theorem whenever it encounters sufficiently many of the propositional hypotheses + When applied to a theorem `[{attrName} →]` will instantiate the theorem whenever it encounters sufficiently many of the propositional hypotheses (that is, it will use the theorem for forwards reasoning).\ \ - The attribute `[grind]` by itself will effectively try `[grind ←]` (if the conclusion is sufficient for instantiation) and then `[grind →]`.\ + The attribute `[{attrName}]` by itself will effectively try `[{attrName} ←]` (if the conclusion is sufficient for instantiation) and then `[{attrName} →]`.\ \ The `grind` tactic utilizes annotated theorems to add instances of matching patterns into the local context during proof search.\ - For example, if a theorem `@[grind =] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\ + For example, if a theorem `@[{attrName} =] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\ `grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`." applicationTime := .afterCompilation add := fun declName stx attrKind => MetaM.run' do @@ -105,49 +154,111 @@ private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit -- When the body is not available (i.e. the def equations are private), the attribute will not -- be exported; see `ematchTheoremsExt.exportEntry?`. withoutExporting do - match (← getAttrKindFromOpt stx) with - | .ematch .user => throwInvalidUsrModifier - | .ematch k => addEMatchAttr declName attrKind k (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) - | .cases eager => addCasesAttr declName eager attrKind - | .intro => - if let some info ← isCasesAttrPredicateCandidate? declName false then - for ctor in info.ctors do - addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) - else - throwError "invalid `[grind intro]`, `{.ofConstName declName}` is not an inductive predicate" - | .ext => addExtAttr declName attrKind - | .infer => - if let some declName ← isCasesAttrCandidate? declName false then - addCasesAttr declName false attrKind - if let some info ← isInductivePredicate? declName then - -- If it is an inductive predicate, - -- we also add the constructors (intro rules) as E-matching rules + if let some ext := ext? then + match (← getAttrKindFromOpt stx) with + | .symbol prio => + unless attrName == `grind do + throwError "symbol priorities must be set using the default `[grind]` attribute" + addSymbolPriorityAttr declName attrKind prio + | .cases eager => ext.addCasesAttr declName eager attrKind + | .funCC => ext.addFunCCAttr declName attrKind + | .ext => ext.addExtAttr declName attrKind + | .ematch .user => throwInvalidUsrModifier + | .ematch k => ext.addEMatchAttr declName attrKind k (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) + | .intro => + if let some info ← isCasesAttrPredicateCandidate? declName false then + for ctor in info.ctors do + ext.addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) + else + throwError "invalid `[{attrName} intro]`, `{.ofConstName declName}` is not an inductive predicate" + | .infer => + if let some declName ← isCasesAttrCandidate? declName false then + ext.addCasesAttr declName false attrKind + if let some info ← isInductivePredicate? declName then + -- If it is an inductive predicate, + -- we also add the constructors (intro rules) as E-matching rules + for ctor in info.ctors do + ext.addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) + else + ext.addEMatchAttrAndSuggest stx declName attrKind (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) + | .inj => ext.addInjectiveAttr declName attrKind + else + -- **TODO**: delete after update stage0 and new extension for default `grind` attribute + match (← getAttrKindFromOpt stx) with + | .ematch .user => throwInvalidUsrModifier + | .ematch k => addEMatchAttr declName attrKind k (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) + | .cases eager => addCasesAttr declName eager attrKind + | .intro => + if let some info ← isCasesAttrPredicateCandidate? declName false then for ctor in info.ctors do addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) - else - addEMatchAttrAndSuggest stx declName attrKind (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) - | .symbol prio => addSymbolPriorityAttr declName attrKind prio - | .inj => addInjectiveAttr declName attrKind - | .funCC => addFunCCAttr declName attrKind + else + throwError "invalid `[{attrName} intro]`, `{.ofConstName declName}` is not an inductive predicate" + | .ext => addExtAttr declName attrKind + | .infer => + if let some declName ← isCasesAttrCandidate? declName false then + addCasesAttr declName false attrKind + if let some info ← isInductivePredicate? declName then + -- If it is an inductive predicate, + -- we also add the constructors (intro rules) as E-matching rules + for ctor in info.ctors do + addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) + else + addEMatchAttrAndSuggest stx declName attrKind (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo) + | .symbol prio => addSymbolPriorityAttr declName attrKind prio + | .inj => addInjectiveAttr declName attrKind + | .funCC => addFunCCAttr declName attrKind erase := fun declName => MetaM.run' do if showInfo then - throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead" - if (← isCasesAttrCandidate declName false) then - eraseCasesAttr declName - else if (← isExtTheorem declName) then - eraseExtAttr declName - else if (← isInjectiveTheorem declName) then - eraseInjectiveAttr declName - else if (← hasFunCCAttr declName) then - eraseFunCCAttr declName + throwError "`[{attrName}?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[{attrName}]` instead" + if let some ext := ext? then + if (← isCasesAttrCandidate declName false) then + ext.eraseCasesAttr declName + else if (← ext.isExtTheorem declName) then + ext.eraseExtAttr declName + else if (← ext.isInjectiveTheorem declName) then + ext.eraseInjectiveAttr declName + else if (← ext.hasFunCCAttr declName) then + ext.eraseFunCCAttr declName + else + ext.eraseEMatchAttr declName else - eraseEMatchAttr declName + -- **TODO**: delete after update stage0 and new extension for default `grind` attribute + if (← isCasesAttrCandidate declName false) then + eraseCasesAttr declName + else if (← isExtTheorem declName) then + eraseExtAttr declName + else if (← isInjectiveTheorem declName) then + eraseInjectiveAttr declName + else if (← hasFunCCAttr declName) then + eraseFunCCAttr declName + else + eraseEMatchAttr declName } +private def registerDefaultGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit := + mkGrindAttr `grind minIndexable showInfo + builtin_initialize - registerGrindAttr (minIndexable := false) (showInfo := true) - registerGrindAttr (minIndexable := false) (showInfo := false) - registerGrindAttr (minIndexable := true) (showInfo := true) - registerGrindAttr (minIndexable := true) (showInfo := false) + registerDefaultGrindAttr (minIndexable := false) (showInfo := true) + registerDefaultGrindAttr (minIndexable := false) (showInfo := false) + registerDefaultGrindAttr (minIndexable := true) (showInfo := true) + registerDefaultGrindAttr (minIndexable := true) (showInfo := false) + +abbrev ExtensionMap := Std.HashMap Name Extension + +builtin_initialize extensionMapRef : IO.Ref ExtensionMap ← IO.mkRef {} + +def getExtension? (attrName : Name) : IO (Option Extension) := + return (← extensionMapRef.get)[attrName]? + +def registerAttr (attrName : Name) (ref : Name := by exact decl_name%) : IO Extension := do + let ext ← mkExtension ref + mkGrindAttr attrName (minIndexable := false) (showInfo := true) (ext? := some ext) (ref := ref) + mkGrindAttr attrName (minIndexable := false) (showInfo := false) (ext? := some ext) (ref := ref) + mkGrindAttr attrName (minIndexable := true) (showInfo := true) (ext? := some ext) (ref := ref) + mkGrindAttr attrName (minIndexable := true) (showInfo := false) (ext? := some ext) (ref := ref) + extensionMapRef.modify fun map => map.insert attrName ext + return ext end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Cases.lean b/src/Lean/Meta/Tactic/Grind/Cases.lean index 01c41a6112..59179f3b67 100644 --- a/src/Lean/Meta/Tactic/Grind/Cases.lean +++ b/src/Lean/Meta/Tactic/Grind/Cases.lean @@ -6,19 +6,18 @@ Authors: Leonardo de Moura module prelude public import Lean.Meta.Tactic.Cases +public import Lean.Meta.Tactic.Grind.Extension public section namespace Lean.Meta.Grind - -/-- Types that `grind` will case-split on. -/ -structure CasesTypes where - casesMap : PHashMap Name Bool := {} - deriving Inhabited - +-- TODO: delete structure CasesEntry where declName : Name eager : Bool deriving Inhabited +/-- A collection of `CasesTypes`. -/ +abbrev CasesTypesArray := Array CasesTypes + /-- `grind` always case-splits on the following types. Even when using `grind only`. The goal is to reduce noise in the tactic generated by `grind?` @@ -43,9 +42,6 @@ def CasesTypes.contains (s : CasesTypes) (declName : Name) : Bool := def CasesTypes.erase (s : CasesTypes) (declName : Name) : CasesTypes := { s with casesMap := s.casesMap.erase declName } -def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesTypes := - { s with casesMap := s.casesMap.insert declName eager } - def CasesTypes.find? (s : CasesTypes) (declName : Name) : Option Bool := s.casesMap.find? declName @@ -55,6 +51,9 @@ def CasesTypes.isEagerSplit (s : CasesTypes) (declName : Name) : Bool := def CasesTypes.isSplit (s : CasesTypes) (declName : Name) : Bool := (s.casesMap.find? declName |>.isSome) || isBuiltinEagerCases declName +/- +TODO: group into a `grind` extension object +-/ builtin_initialize casesExt : SimpleScopedEnvExtension CasesEntry CasesTypes ← registerSimpleScopedEnvExtension { initial := {} @@ -68,7 +67,7 @@ def getCasesTypes : CoreM CasesTypes := return casesExt.getState (← getEnv) /-- Returns `true` is `declName` is a builtin split or has been tagged with `[grind]` attribute. -/ -def isSplit (declName : Name) : CoreM Bool := do +def isGlobalSplit (declName : Name) : CoreM Bool := do return (← getCasesTypes).isSplit declName partial def isCasesAttrCandidate? (declName : Name) (eager : Bool) : CoreM (Option Name) := do @@ -98,7 +97,7 @@ def CasesTypes.eraseDecl (s : CasesTypes) (declName : Name) : CoreM CasesTypes : if s.contains declName then return s.erase declName else - throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute" + throwNotMarkedWithGrindAttribute declName def ensureNotBuiltinCases (declName : Name) : CoreM Unit := do if isBuiltinEagerCases declName then diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 57364321d0..3331329aec 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -5,7 +5,7 @@ Authors: Leonardo de Moura -/ module prelude -public import Lean.Meta.Tactic.Grind.Theorems +public import Lean.Meta.Tactic.Grind.Extension import Init.Grind.Util import Lean.Util.ForEachExpr import Lean.Meta.Tactic.Grind.Util @@ -13,14 +13,7 @@ import Lean.Meta.Match.Basic import Lean.Meta.Tactic.TryThis public section namespace Lean.Meta.Grind -/-- -`grind` uses symbol priorities when inferring patterns for E-matching. -Symbols not in `map` are assumed to have default priority (i.e., `eval_prio default`). --/ -structure SymbolPriorities where - map : PHashMap Name Nat := {} - deriving Inhabited - +-- TODO: delete structure SymbolPriorityEntry where declName : Name prio : Nat @@ -30,10 +23,6 @@ structure SymbolPriorityEntry where def SymbolPriorities.erase (s : SymbolPriorities) (declName : Name) : SymbolPriorities := { s with map := s.map.erase declName } -/-- Inserts `declName ↦ prio` into `s`. -/ -def SymbolPriorities.insert (s : SymbolPriorities) (declName : Name) (prio : Nat) : SymbolPriorities := - { s with map := s.map.insert declName prio } - /-- Returns `declName` priority for E-matching pattern inference in `s`. -/ def SymbolPriorities.getPrio (s : SymbolPriorities) (declName : Name) : Nat := if let some prio := s.map.find? declName then @@ -48,6 +37,9 @@ Recall that symbols not in `s` are assumed to have default priority. def SymbolPriorities.contains (s : SymbolPriorities) (declName : Name) : Bool := s.map.contains declName +/- +TODO: group into a `grind` extension object +-/ private builtin_initialize symbolPrioExt : SimpleScopedEnvExtension SymbolPriorityEntry SymbolPriorities ← registerSimpleScopedEnvExtension { initial := {} @@ -286,19 +278,6 @@ def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do let pat ← foldProjs pat return pat -inductive EMatchTheoremKind where - | eqLhs (gen : Bool) - | eqRhs (gen : Bool) - | eqBoth (gen : Bool) - | eqBwd - | fwd - | bwd (gen : Bool) - | leftRight - | rightLeft - | default (gen : Bool) - | user /- pattern specified using `grind_pattern` command -/ - deriving Inhabited, BEq, Repr, Hashable - def EMatchTheoremKind.isEqLhs : EMatchTheoremKind → Bool | .eqLhs _ => true | _ => false @@ -345,106 +324,13 @@ private def EMatchTheoremKind.explainFailure : EMatchTheoremKind → String | .default _ => "failed to find patterns" | .user => unreachable! -structure CnstrRHS where - /-- Abstracted universe level param names in the `rhs` -/ - levelNames : Array Name - /-- Number of abstracted metavariable in the `rhs` -/ - numMVars : Nat - /-- The actual `rhs`. -/ - expr : Expr - deriving Inhabited, BEq, Repr - -/-- -Grind patterns may have constraints associated with them. --/ -inductive EMatchTheoremConstraint where - | /-- - A constraint of the form `lhs =/= rhs`. - The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally - equal to a term `t` assigned to `lhs`. -/ - notDefEq (lhs : Nat) (rhs : CnstrRHS) - | /-- - A constraint of the form `lhs =?= rhs`. - The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally - equal to a term `t` assigned to `lhs`. -/ - defEq (lhs : Nat) (rhs : CnstrRHS) - | /-- - A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables. - The size is computed ignoring implicit terms, but sharing is not taken into account. - -/ - sizeLt (lhs : Nat) (n : Nat) - | /-- - A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables. - The depth is computed in constant time using the `approxDepth` field attached to expressions. - -/ - depthLt (lhs : Nat) (n : Nat) - | /-- - Instantiates the theorem only if its generation is less than `n` - -/ - genLt (n : Nat) - | /-- - Constraints of the form `is_ground x`. Instantiates the theorem only if - `x` is ground term. - -/ - isGround (bvarIdx : Nat) - | /-- - Constraints of the form `is_value x` and `is_strict_value x`. - A value is defined as - - A constructor fully applied to value arguments. - - A literal: numerals, strings, etc. - - A lambda. In the strict case, lambdas are not considered. - -/ - isValue (bvarIdx : Nat) (strict : Bool) - | /-- - Instantiates the theorem only if less than `n` instances have been generated for this theorem. - -/ - maxInsts (n : Nat) - | /-- - It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`. - -/ - guard (e : Expr) - | /-- - Similar to `guard`, but checks whether `e` is implied by asserting `¬e`. - -/ - check (e : Expr) - | /-- - Constraints of the form `not_value x` and `not_strict_value x`. - They are the negations of `is_value x` and `is_strict_value x`. - -/ - notValue (bvarIdx : Nat) (strict : Bool) - deriving Inhabited, Repr, BEq - -/-- A theorem for heuristic instantiation based on E-matching. -/ -structure EMatchTheorem where - /-- - It stores universe parameter names for universe polymorphic proofs. - Recall that it is non-empty only when we elaborate an expression provided by the user. - When `proof` is just a constant, we can use the universe parameter names stored in the declaration. - -/ - levelParams : Array Name - proof : Expr - numParams : Nat - patterns : List Expr - /-- Contains all symbols used in `patterns`. -/ - symbols : List HeadIndex - origin : Origin - /-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/ - kind : EMatchTheoremKind - /-- Stores whether patterns were inferred using the minimal indexable subexpression condition. -/ - minIndexable : Bool - cnstrs : List EMatchTheoremConstraint := [] - deriving Inhabited - -instance : TheoremLike EMatchTheorem where - getSymbols thm := thm.symbols - setSymbols thm symbols := { thm with symbols } - getOrigin thm := thm.origin - getProof thm := thm.proof - getLevelParams thm := thm.levelParams /-- Set of E-matching theorems. -/ abbrev EMatchTheorems := Theorems EMatchTheorem +/-- A collection of sets of E-matching theorems. -/ +abbrev EMatchTheoremsArray := TheoremsArray EMatchTheorem + /-- Returns `true` if there is a theorem with exactly the same pattern and constraints is already in `s` -/ @@ -453,6 +339,10 @@ def EMatchTheorems.containsWithSamePatterns (s : EMatchTheorems) (origin : Origi let thms := s.find origin thms.any fun thm => thm.patterns == patterns && thm.cnstrs == cnstrs +def ExtensionStateArray.containsWithSamePatterns (s : ExtensionStateArray) (origin : Origin) + (patterns : List Expr) (cnstrs : List EMatchTheoremConstraint) : Bool := + s.any (EMatchTheorems.containsWithSamePatterns ·.ematch origin patterns cnstrs) + def EMatchTheorems.getKindsFor (s : EMatchTheorems) (origin : Origin) : List EMatchTheoremKind := let thms := s.find origin thms.map (·.kind) @@ -460,6 +350,9 @@ def EMatchTheorems.getKindsFor (s : EMatchTheorems) (origin : Origin) : List EMa def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do Grind.getProofWithFreshMVarLevels thm +/- +TODO: group into a `grind` extension object +-/ private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem (Theorems EMatchTheorem) ← registerSimpleScopedEnvExtension { addEntry := Theorems.insert @@ -1482,6 +1375,7 @@ def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Opt eqns.mapM fun eqn => do mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo) +-- TODO: delete private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do if wasOriginallyTheorem (← getEnv) declName then ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (gen := thmKind.gen) (showInfo := showInfo)) attrKind @@ -1492,26 +1386,34 @@ private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind else throwError s!"`{thmKind.toAttribute false}` attribute can only be applied to equational theorems or function definitions" +private def Extension.addGrindEqAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do + if wasOriginallyTheorem (← getEnv) declName then + ext.add (.ematch (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (gen := thmKind.gen) (showInfo := showInfo))) attrKind + else if let some thms ← mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then + unless useLhs do + throwError "`{.ofConstName declName}` is a definition, you must only use the left-hand side for extracting patterns" + thms.forM fun thm => ext.add (.ematch thm) attrKind + else + throwError s!"`{thmKind.toAttribute false}` attribute can only be applied to equational theorems or function definitions" + def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMatchTheorems := do - let throwErr {α} : MetaM α := - throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute" if !wasOriginallyTheorem (← getEnv) declName then if let some eqns ← getEqnsFor? declName then - let s := ematchTheoremsExt.getState (← getEnv) unless eqns.all fun eqn => s.contains (.decl eqn) do - throwErr + throwNotMarkedWithGrindAttribute declName return eqns.foldl (init := s) fun s eqn => s.erase (.decl eqn) else - throwErr + throwNotMarkedWithGrindAttribute declName else - unless ematchTheoremsExt.getState (← getEnv) |>.contains (.decl declName) do - throwErr + unless s.contains (.decl declName) do + throwNotMarkedWithGrindAttribute declName return s.erase <| .decl declName private def ensureNoMinIndexable (minIndexable : Bool) : MetaM Unit := do if minIndexable then throwError "redundant modifier `!` in `grind` attribute" +-- TODO: delete def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (prios : SymbolPriorities) (showInfo := false) (minIndexable := false) : MetaM Unit := do match thmKind with @@ -1534,6 +1436,28 @@ def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatch let thm ← mkEMatchTheoremForDecl declName thmKind prios (showInfo := showInfo) (minIndexable := minIndexable) ematchTheoremsExt.add thm attrKind +def Extension.addEMatchAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (prios : SymbolPriorities) + (showInfo := false) (minIndexable := false) : MetaM Unit := do + match thmKind with + | .eqLhs _ => + ensureNoMinIndexable minIndexable + ext.addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo) + | .eqRhs _ => + ensureNoMinIndexable minIndexable + ext.addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo) + | .eqBoth _ => + ensureNoMinIndexable minIndexable + ext.addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo) + ext.addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo) + | _ => + let info ← getConstInfo declName + if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then + ensureNoMinIndexable minIndexable + ext.addGrindEqAttr declName attrKind thmKind (showInfo := showInfo) + else + let thm ← mkEMatchTheoremForDecl declName thmKind prios (showInfo := showInfo) (minIndexable := minIndexable) + ext.add (.ematch thm) attrKind + private structure SelectM.State where -- **Note**: hack, an attribute is not a tactic. suggestions : Array Tactic.TryThis.Suggestion := #[] @@ -1665,6 +1589,7 @@ Tries different modifiers, logs info messages with modifiers that worked, but st Remark: if `backward.grind.inferPattern` is `true`, then `.default false` is used. The parameter `showInfo` is only taken into account when `backward.grind.inferPattern` is `true`. -/ +-- TODO: delete def addEMatchAttrAndSuggest (ref : Syntax) (declName : Name) (attrKind : AttributeKind) (prios : SymbolPriorities) (minIndexable : Bool) (showInfo : Bool) (isParam : Bool := false) : MetaM Unit := do let info ← getConstInfo declName @@ -1677,6 +1602,25 @@ def addEMatchAttrAndSuggest (ref : Syntax) (declName : Name) (attrKind : Attribu let thm ← mkEMatchTheoremAndSuggest ref declName prios minIndexable isParam ematchTheoremsExt.add thm attrKind +/-- +Tries different modifiers, logs info messages with modifiers that worked, but stores just the first one that worked. + +Remark: if `backward.grind.inferPattern` is `true`, then `.default false` is used. +The parameter `showInfo` is only taken into account when `backward.grind.inferPattern` is `true`. +-/ +-- TODO: delete +def Extension.addEMatchAttrAndSuggest (ext : Extension) (ref : Syntax) (declName : Name) (attrKind : AttributeKind) (prios : SymbolPriorities) + (minIndexable : Bool) (showInfo : Bool) (isParam : Bool := false) : MetaM Unit := do + let info ← getConstInfo declName + if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then + ensureNoMinIndexable minIndexable + ext.addGrindEqAttr declName attrKind (.default false) (showInfo := showInfo) + else if backward.grind.inferPattern.get (← getOptions) then + ext.addEMatchAttr declName attrKind (.default false) prios (minIndexable := minIndexable) (showInfo := showInfo) + else + let thm ← mkEMatchTheoremAndSuggest ref declName prios minIndexable isParam + ext.add (.ematch thm) attrKind + def eraseEMatchAttr (declName : Name) : MetaM Unit := do /- Remark: consider the following example diff --git a/src/Lean/Meta/Tactic/Grind/ExtAttr.lean b/src/Lean/Meta/Tactic/Grind/ExtAttr.lean index 6f8dc50f6c..86305941ec 100644 --- a/src/Lean/Meta/Tactic/Grind/ExtAttr.lean +++ b/src/Lean/Meta/Tactic/Grind/ExtAttr.lean @@ -6,13 +6,14 @@ Authors: Leonardo de Moura module prelude public import Lean.Meta.Tactic.Ext +public import Lean.Meta.Tactic.Grind.Extension public section namespace Lean.Meta.Grind /-! Grind extensionality attribute to mark which `[ext]` theorems should be used. -/ -/-- Extensionality theorems that can be used by `grind` -/ -abbrev ExtTheorems := PHashSet Name - +/- +TODO: group into a `grind` extension object +-/ builtin_initialize extTheoremsExt : SimpleScopedEnvExtension Name ExtTheorems ← registerSimpleScopedEnvExtension { initial := {} @@ -28,7 +29,7 @@ def addExtAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do validateExtAttr declName extTheoremsExt.add declName attrKind -private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do +def ExtTheorems.eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do if s.contains declName then return s.erase declName else @@ -36,10 +37,13 @@ private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := def eraseExtAttr (declName : Name) : CoreM Unit := do let s := extTheoremsExt.getState (← getEnv) - let s ← eraseDecl s declName + let s ← s.eraseDecl declName modifyEnv fun env => extTheoremsExt.modifyState env fun _ => s def isExtTheorem (declName : Name) : CoreM Bool := do return extTheoremsExt.getState (← getEnv) |>.contains declName +def getGlobalExtTheorems : CoreM ExtTheorems := do + return extTheoremsExt.getState (← getEnv) + end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Extension.lean b/src/Lean/Meta/Tactic/Grind/Extension.lean new file mode 100644 index 0000000000..d4101398b8 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Extension.lean @@ -0,0 +1,221 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Expr +public import Lean.Data.PersistentHashMap +public import Lean.Meta.Tactic.Grind.Theorems +public section +namespace Lean.Meta.Grind + +/-- Types that `grind` will case-split on. -/ +structure CasesTypes where + casesMap : PHashMap Name Bool := {} + deriving Inhabited + +def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesTypes := + { s with casesMap := s.casesMap.insert declName eager } + +abbrev ExtTheorems := PHashSet Name + +structure SymbolPriorities where + map : PHashMap Name Nat := {} + deriving Inhabited + +/-- Inserts `declName ↦ prio` into `s`. -/ +def SymbolPriorities.insert (s : SymbolPriorities) (declName : Name) (prio : Nat) : SymbolPriorities := + { s with map := s.map.insert declName prio } + +inductive EMatchTheoremKind where + | eqLhs (gen : Bool) + | eqRhs (gen : Bool) + | eqBoth (gen : Bool) + | eqBwd + | fwd + | bwd (gen : Bool) + | leftRight + | rightLeft + | default (gen : Bool) + | user /- pattern specified using `grind_pattern` command -/ + deriving Inhabited, BEq, Repr, Hashable + +structure CnstrRHS where + /-- Abstracted universe level param names in the `rhs` -/ + levelNames : Array Name + /-- Number of abstracted metavariable in the `rhs` -/ + numMVars : Nat + /-- The actual `rhs`. -/ + expr : Expr + deriving Inhabited, BEq, Repr + +/-- +Grind patterns may have constraints associated with them. +-/ +inductive EMatchTheoremConstraint where + | /-- + A constraint of the form `lhs =/= rhs`. + The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally + equal to a term `t` assigned to `lhs`. -/ + notDefEq (lhs : Nat) (rhs : CnstrRHS) + | /-- + A constraint of the form `lhs =?= rhs`. + The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally + equal to a term `t` assigned to `lhs`. -/ + defEq (lhs : Nat) (rhs : CnstrRHS) + | /-- + A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables. + The size is computed ignoring implicit terms, but sharing is not taken into account. + -/ + sizeLt (lhs : Nat) (n : Nat) + | /-- + A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables. + The depth is computed in constant time using the `approxDepth` field attached to expressions. + -/ + depthLt (lhs : Nat) (n : Nat) + | /-- + Instantiates the theorem only if its generation is less than `n` + -/ + genLt (n : Nat) + | /-- + Constraints of the form `is_ground x`. Instantiates the theorem only if + `x` is ground term. + -/ + isGround (bvarIdx : Nat) + | /-- + Constraints of the form `is_value x` and `is_strict_value x`. + A value is defined as + - A constructor fully applied to value arguments. + - A literal: numerals, strings, etc. + - A lambda. In the strict case, lambdas are not considered. + -/ + isValue (bvarIdx : Nat) (strict : Bool) + | /-- + Instantiates the theorem only if less than `n` instances have been generated for this theorem. + -/ + maxInsts (n : Nat) + | /-- + It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`. + -/ + guard (e : Expr) + | /-- + Similar to `guard`, but checks whether `e` is implied by asserting `¬e`. + -/ + check (e : Expr) + | /-- + Constraints of the form `not_value x` and `not_strict_value x`. + They are the negations of `is_value x` and `is_strict_value x`. + -/ + notValue (bvarIdx : Nat) (strict : Bool) + deriving Inhabited, Repr, BEq + +/-- A theorem for heuristic instantiation based on E-matching. -/ +structure EMatchTheorem where + /-- + It stores universe parameter names for universe polymorphic proofs. + Recall that it is non-empty only when we elaborate an expression provided by the user. + When `proof` is just a constant, we can use the universe parameter names stored in the declaration. + -/ + levelParams : Array Name + proof : Expr + numParams : Nat + patterns : List Expr + /-- Contains all symbols used in `patterns`. -/ + symbols : List HeadIndex + origin : Origin + /-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/ + kind : EMatchTheoremKind + /-- Stores whether patterns were inferred using the minimal indexable subexpression condition. -/ + minIndexable : Bool + cnstrs : List EMatchTheoremConstraint := [] + deriving Inhabited + +instance : TheoremLike EMatchTheorem where + getSymbols thm := thm.symbols + setSymbols thm symbols := { thm with symbols } + getOrigin thm := thm.origin + getProof thm := thm.proof + getLevelParams thm := thm.levelParams + +/-- A theorem marked with `@[grind inj]` -/ +structure InjectiveTheorem where + levelParams : Array Name + proof : Expr + /-- Contains all symbols used in the term `f` at the theorem's conclusion: `Function.Injective f`. -/ + symbols : List HeadIndex + origin : Origin + deriving Inhabited + +instance : TheoremLike InjectiveTheorem where + getSymbols thm := thm.symbols + setSymbols thm symbols := { thm with symbols } + getOrigin thm := thm.origin + getProof thm := thm.proof + getLevelParams thm := thm.levelParams + +inductive Entry where + | ext (declName : Name) + | funCC (declName : Name) + | cases (declName : Name) (eager : Bool) + | ematch (thm : EMatchTheorem) + | inj (thm : InjectiveTheorem) + deriving Inhabited + +/- +**Note**: We currently have a single normalization and symbol priority sets for all `grind` attributes. +Reason: the E-match patterns must be normalized with respect to them. If we are using multiple +`grind` attributes, they patterns would have to be re-normalized using the union of all normalizers. + +Alternative design: allow a single `grind` attribute per `grind` call. Cons: when creating a new +`grind` attribute users would have to carefully setup the normalizer to ensure all `grind` assumptions +are met. Cons: users would not be able to write `grind only [attr_1, attr_2]`. +-/ + +structure ExtensionState where + casesTypes : CasesTypes := {} + extThms : ExtTheorems := {} + funCC : NameSet := {} + ematch : Theorems EMatchTheorem := {} + inj : Theorems InjectiveTheorem := {} + deriving Inhabited + +abbrev Extension := SimpleScopedEnvExtension Entry ExtensionState + +def ExtensionState.addEntry (s : ExtensionState) (e : Entry) : ExtensionState := + match e with + | .cases declName eager => { s with casesTypes := s.casesTypes.insert declName eager } + | .ext declName => { s with extThms := s.extThms.insert declName } + | .funCC declName => { s with funCC := s.funCC.insert declName } + | .ematch thm => { s with ematch := s.ematch.insert thm } + | .inj thm => { s with inj := s.inj.insert thm } + +def mkExtension (name : Name := by exact decl_name%) : IO Extension := + registerSimpleScopedEnvExtension { + name := name + initial := {} + addEntry := ExtensionState.addEntry + exportEntry? := fun lvl e => do + -- export only annotations on public decls + let declName := match e with + | .inj thm | .ematch thm => + match thm.origin with + | .decl declName => declName + | _ => unreachable! + | .ext declName | .cases declName _ | .funCC declName => declName + guard (lvl == .private || !isPrivateName declName) + return e + } + +/-- +`grind` is parametrized by a collection of `ExtensionState`. The motivation is to allow +users to use multiple extensions simultaneously without merging them into a single structure. +The collection is scanned linearly. In practice, we expect the array to be very small. +-/ +abbrev ExtensionStateArray := Array ExtensionState + +def throwNotMarkedWithGrindAttribute (declName : Name) : CoreM α := + throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute" + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/FunCC.lean b/src/Lean/Meta/Tactic/Grind/FunCC.lean index d862432363..6486a07a24 100644 --- a/src/Lean/Meta/Tactic/Grind/FunCC.lean +++ b/src/Lean/Meta/Tactic/Grind/FunCC.lean @@ -9,6 +9,9 @@ public import Lean.ScopedEnvExtension public section namespace Lean.Meta.Grind +/- +TODO: group into a `grind` extension object +-/ private builtin_initialize funCCExt : SimpleScopedEnvExtension Name NameSet ← registerSimpleScopedEnvExtension { initial := {} diff --git a/src/Lean/Meta/Tactic/Grind/Injective.lean b/src/Lean/Meta/Tactic/Grind/Injective.lean index 7701897614..aa078eaa60 100644 --- a/src/Lean/Meta/Tactic/Grind/Injective.lean +++ b/src/Lean/Meta/Tactic/Grind/Injective.lean @@ -14,25 +14,15 @@ builtin_initialize registerTraceClass `grind.inj builtin_initialize registerTraceClass `grind.inj.assert builtin_initialize registerTraceClass `grind.debug.inj -/-- A theorem marked with `@[grind inj]` -/ -structure InjectiveTheorem where - levelParams : Array Name - proof : Expr - /-- Contains all symbols used in the term `f` at the theorem's conclusion: `Function.Injective f`. -/ - symbols : List HeadIndex - origin : Origin - deriving Inhabited - -instance : TheoremLike InjectiveTheorem where - getSymbols thm := thm.symbols - setSymbols thm symbols := { thm with symbols } - getOrigin thm := thm.origin - getProof thm := thm.proof - getLevelParams thm := thm.levelParams - /-- Set of Injective theorems. -/ abbrev InjectiveTheorems := Theorems InjectiveTheorem +/-- A collections of sets of Injective theorems. -/ +abbrev InjectiveTheoremsArray := TheoremsArray InjectiveTheorem + +/- +TODO: group into a `grind` extension object +-/ private builtin_initialize injectiveTheoremsExt : SimpleScopedEnvExtension InjectiveTheorem (Theorems InjectiveTheorem) ← registerSimpleScopedEnvExtension { addEntry := Theorems.insert @@ -85,9 +75,13 @@ def mkInjectiveTheorem (declName : Name) : MetaM InjectiveTheorem := do proof, symbols } +-- TODO: delete def addInjectiveAttr (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do injectiveTheoremsExt.add (← mkInjectiveTheorem declName) attrKind +def Extension.addInjectiveAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do + ext.add (.inj (← mkInjectiveTheorem declName)) attrKind + def eraseInjectiveAttr (declName : Name) : MetaM Unit := do let s := injectiveTheoremsExt.getState (← getEnv) let s ← s.eraseDecl declName diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index 5210d009f4..07686a458b 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -146,13 +146,13 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do return () unless (← isInductivePredicate declName) do return () - if (← get).split.casesTypes.isSplit declName then + if (← isSplit declName) then addDefaultSplitCandidate e else if (← getConfig).splitIndPred then addDefaultSplitCandidate e | .fvar .. => let .const declName _ := (← whnf (← inferType e)).getAppFn | return () - if (← get).split.casesTypes.isSplit declName then + if (← isSplit declName) then addDefaultSplitCandidate e | .forallE _ d _ _ => let currSplitSource := (← readThe Context).splitSource @@ -275,8 +275,8 @@ private def addMatchEqns (f : Expr) (generation : Nat) : GoalM Unit := do @[specialize] private def activateTheoremsCore [TheoremLike α] (declName : Name) - (getThms : GoalM (Theorems α)) - (setThms : Theorems α → GoalM Unit) + (getThms : GoalM (TheoremsArray α)) + (setThms : TheoremsArray α → GoalM Unit) (reinsertThm : α → GoalM Unit) (activateThm : α → GoalM Unit) : GoalM Unit := do if let some (thms, s) := (← getThms).retrieve? declName then @@ -444,7 +444,7 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do Returns `true` if we should use `funCC` for applications of the given constant symbol. -/ private def useFunCongrAtDecl (declName : Name) : GrindM Bool := do - if (← readThe Grind.Context).funCCs.contains declName then + if (← hasFunCCModifier declName) then return true if (← isInstance declName) then /- **Note**: Instances are support elements. No `funCC` -/ diff --git a/src/Lean/Meta/Tactic/Grind/Intro.lean b/src/Lean/Meta/Tactic/Grind/Intro.lean index a4d8c95cfc..5c29fe2a5b 100644 --- a/src/Lean/Meta/Tactic/Grind/Intro.lean +++ b/src/Lean/Meta/Tactic/Grind/Intro.lean @@ -182,9 +182,9 @@ private partial def introNext (goal : Goal) (generation : Nat) : GrindM IntroRes else return .done goal -private def isEagerCasesCandidate (goal : Goal) (type : Expr) : Bool := Id.run do +private def isEagerCasesCandidate (type : Expr) : GrindM Bool := do let .const declName _ := type.getAppFn | return false - return goal.split.casesTypes.isEagerSplit declName + isEagerSplit declName /-- Returns `true` if `type` is an inductive type with at most one constructor. -/ private def isCheapInductive (type : Expr) : CoreM Bool := do @@ -215,7 +215,7 @@ private def applyCases? (goal : Goal) (fvarId : FVarId) (kp : ActionCont) : Grin Example: `a ∣ b` is defined as `∃ x, b = a * x` -/ let type ← whnf (← fvarId.getType) - unless isEagerCasesCandidate goal type do return none + unless (← isEagerCasesCandidate type) do return none if (← cheapCasesOnly) then unless (← isCheapInductive type) do return none if let .const declName _ := type.getAppFn then @@ -268,7 +268,7 @@ def intros (generation : Nat) : Action := /-- Asserts a new fact `prop` with proof `proof` to the given `goal`. -/ private def assertAt (proof : Expr) (prop : Expr) (generation : Nat) : Action := fun goal kna kp => do - if isEagerCasesCandidate goal prop then + if (← isEagerCasesCandidate prop) then let mvarId ← goal.mvarId.assert (← mkFreshUserName `h) prop proof intros generation { goal with mvarId } kna kp else goal.withContext do diff --git a/src/Lean/Meta/Tactic/Grind/Main.lean b/src/Lean/Meta/Tactic/Grind/Main.lean index 01ee043db8..639c09846a 100644 --- a/src/Lean/Meta/Tactic/Grind/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Main.lean @@ -32,26 +32,51 @@ import Lean.Meta.Tactic.Grind.Core public section namespace Lean.Meta.Grind +/-- +Returns the `ExtensionState` for the default `grind` attribute. +-/ +def getDefaultExtensionState : MetaM ExtensionState := do + -- **TODO**: update after update stage0 + let casesTypes ← getCasesTypes + let funCC ← getFunCCSet + let extThms ← getGlobalExtTheorems + let ematch ← getEMatchTheorems + let inj ← getInjectiveTheorems + return { + casesTypes, funCC, extThms, ematch, inj + } + +def getOnlyExtensionState : MetaM ExtensionState := do + let casesTypes ← getCasesTypes + let funCC ← getFunCCSet + let extThms ← getGlobalExtTheorems + return { + casesTypes, funCC, extThms + } + structure Params where config : Grind.Config - ematch : EMatchTheorems := default - inj : InjectiveTheorems := default - symPrios : SymbolPriorities := {} - casesTypes : CasesTypes := {} + extensions : ExtensionStateArray := #[] extra : PArray EMatchTheorem := {} extraInj : PArray InjectiveTheorem := {} extraFacts : PArray Expr := {} - funCCs : NameSet := {} + symPrios : SymbolPriorities := {} norm : Simp.Context normProcs : Array Simprocs anchorRefs? : Option (Array AnchorRef) := none -- TODO: inductives to split -def mkParams (config : Grind.Config) : MetaM Params := do +def mkParams (config : Grind.Config) (extensions : ExtensionStateArray) : MetaM Params := do let norm ← Grind.getSimpContext config let normProcs ← Grind.getSimprocs let symPrios ← getGlobalSymbolPriorities - return { config, norm, normProcs, symPrios } + return { config, norm, normProcs, symPrios, extensions } + +def mkDefaultParams (config : Grind.Config) : MetaM Params := do + mkParams config #[← getDefaultExtensionState] + +def mkOnlyParams (config : Grind.Config) : MetaM Params := do + mkParams config #[← getOnlyExtensionState] def mkMethods (evalTactic? : Option EvalTactic := none) : CoreM Methods := do let builtinPropagators ← builtinPropagatorsRef.get @@ -99,9 +124,11 @@ def GrindM.run (x : GrindM α) (params : Params) (evalTactic? : Option EvalTacti let simp := params.norm let config := params.config let symPrios := params.symPrios + let extensions := params.extensions let anchorRefs? := params.anchorRefs? - let funCCs := params.funCCs - x (← mkMethods evalTactic?).toMethodsRef { config, anchorRefs?, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios, funCCs } + x (← mkMethods evalTactic?).toMethodsRef + { config, anchorRefs?, simpMethods, simp, extensions, symPrios + trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr } |>.run' { scState } private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do @@ -136,11 +163,11 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do let bfalseExpr ← getBoolFalseExpr let natZeroExpr ← getNatZeroExpr let ordEqExpr ← getOrderingEqExpr - let thmMap := params.ematch - let casesTypes := params.casesTypes + let thmEMatch := params.extensions.map fun ext => ext.ematch + let thmInj := params.extensions.map fun ext => ext.inj let clean ← mkCleanState mvarId params let sstates ← Solvers.mkInitialStates - GoalM.run' { mvarId, ematch.thmMap := thmMap, inj.thms := params.inj, split.casesTypes := casesTypes, clean, sstates } do + GoalM.run' { mvarId, ematch.thmMap := thmEMatch, inj.thms := thmInj, clean, sstates } do initENodeCore falseExpr (interpreted := true) (ctor := false) initENodeCore trueExpr (interpreted := true) (ctor := false) initENodeCore btrueExpr (interpreted := false) (ctor := true) diff --git a/src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean b/src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean index 9652fe6de7..be03852578 100644 --- a/src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean +++ b/src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean @@ -52,6 +52,9 @@ private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do declareBuiltin initDeclName val go.run' {} +/- +**Note**: We currently use the same propagators for all `grind` attributes. +-/ builtin_initialize registerBuiltinAttribute { ref := by exact decl_name% diff --git a/src/Lean/Meta/Tactic/Grind/RegisterCommand.lean b/src/Lean/Meta/Tactic/Grind/RegisterCommand.lean new file mode 100644 index 0000000000..2f5dfc6ea9 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/RegisterCommand.lean @@ -0,0 +1,30 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.Types +meta import Lean.Meta.Tactic.Grind.Attr +public section +namespace Lean.Meta.Grind + +macro (name := _root_.Lean.Parser.Command.registerGrindAttr) doc:(docComment)? + "register_grind_attr" id:ident : command => do + let str1 := id.getId.toString + let idParser1 := mkIdentFrom id (`Lean.Parser.Attr ++ id.getId) + let str2 := id.getId.toString ++ "!" + let idParser2 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "!")) + let str3 := id.getId.toString ++ "?" + let idParser3 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "?")) + let str4 := id.getId.toString ++ "!?" + let idParser4 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "!?")) + `($[$doc:docComment]? initialize ext : Extension ← registerAttr $(quote id.getId) (ref := $(quote id.getId)) + $[$doc:docComment]? syntax (name := $idParser1:ident) $(quote str1):str (ppSpace Lean.Parser.Attr.grindMod)? : attr + $[$doc:docComment]? syntax (name := $idParser2:ident) $(quote str2):str (ppSpace Lean.Parser.Attr.grindMod)? : attr + $[$doc:docComment]? syntax (name := $idParser3:ident) $(quote str3):str (ppSpace Lean.Parser.Attr.grindMod)? : attr + $[$doc:docComment]? syntax (name := $idParser4:ident) $(quote str4):str (ppSpace Lean.Parser.Attr.grindMod)? : attr + ) + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean index bc654e3886..43c89e3812 100644 --- a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean @@ -18,6 +18,9 @@ import Init.Grind.Norm public section namespace Lean.Meta.Grind +/- +TODO: group into a `grind` extension object +-/ builtin_initialize normExt : SimpExtension ← mkSimpExt def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name) : MetaM Unit := do @@ -176,14 +179,18 @@ private def addDeclToUnfold (s : SimpTheorems) (declName : Name) : MetaM SimpThe else return s -/-- Returns the simplification context used by `grind`. -/ -protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do +def getNormTheorems : MetaM SimpTheorems := do let mut thms ← normExt.getTheorems thms ← addDeclToUnfold thms ``GE.ge thms ← addDeclToUnfold thms ``GT.gt thms ← addDeclToUnfold thms ``Nat.cast thms ← addDeclToUnfold thms ``Bool.xor thms ← addDeclToUnfold thms ``Ne + return thms + +/-- Returns the simplification context used by `grind`. -/ +protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do + let thms ← getNormTheorems Simp.mkContext (config := { arith := true diff --git a/src/Lean/Meta/Tactic/Grind/Theorems.lean b/src/Lean/Meta/Tactic/Grind/Theorems.lean index 7da122fcde..3b0ed3b8ad 100644 --- a/src/Lean/Meta/Tactic/Grind/Theorems.lean +++ b/src/Lean/Meta/Tactic/Grind/Theorems.lean @@ -174,4 +174,36 @@ def getProofForDecl (declName : Name) : MetaM Expr := do let us := info.levelParams.map mkLevelParam return mkConst declName us +/-- +A `TheoremsArray α` is a collection of `Theorems α`. +The array is scanned linear during theorem activation. +This avoids the need for efficiently merging the `Theorems α` data structure. +-/ +abbrev TheoremsArray (α : Type) := Array (Theorems α) + +@[specialize] +def TheoremsArray.retrieve? (s : TheoremsArray α) (sym : Name) : Option (List α × TheoremsArray α) := Id.run do + for h : i in *...s.size do + if let some (thms, a) ← s[i].retrieve? sym then + return some (thms, s.set i a) + return none + +def TheoremsArray.insert [TheoremLike α] (s : TheoremsArray α) (thm : α) : TheoremsArray α := Id.run do + if s.isEmpty then + let thms := { : Theorems α} + #[thms.insert thm] + else + s.modify 0 (·.insert thm) + +def TheoremsArray.isErased (s : TheoremsArray α) (origin : Origin) : Bool := + s.any fun thms => thms.erased.contains origin + +def TheoremsArray.find (s : TheoremsArray α) (origin : Origin) : List α := Id.run do + let mut r := [] + for h : i in *...s.size do + let thms := s[i].find origin + unless thms.isEmpty do + r := r ++ thms + return r + end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 9f103ac68d..e3800a74b7 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Simp.Types public import Lean.Meta.Tactic.Grind.AlphaShareCommon public import Lean.Meta.Tactic.Grind.Attr public import Lean.Meta.Tactic.Grind.CheckResult +public import Lean.Meta.Tactic.Grind.Extension public import Init.Data.Queue import Lean.Meta.Tactic.Grind.ExprPtr import Lean.HeadIndex @@ -158,8 +159,7 @@ structure Context where splitSource : SplitSource := .input /-- Symbol priorities for inferring E-matching patterns -/ symPrios : SymbolPriorities - /-- Global declarations marked with `@[grind funCC]` -/ - funCCs : NameSet + extensions : ExtensionStateArray := #[] trueExpr : Expr falseExpr : Expr natZExpr : Expr @@ -346,6 +346,18 @@ def reportMVarInternalization : GrindM Bool := def getSymbolPriorities : GrindM SymbolPriorities := do return (← readThe Context).symPrios +/-- +Returns `true` if we `declName` is tagged with `funCC` modifier. +-/ +def hasFunCCModifier (declName : Name) : GrindM Bool := + return (← readThe Context).extensions.any fun ext => ext.funCC.contains declName + +def isSplit (declName : Name) : GrindM Bool := + return (← readThe Context).extensions.any fun ext => ext.casesTypes.isSplit declName + +def isEagerSplit (declName : Name) : GrindM Bool := + return (← readThe Context).extensions.any fun ext => ext.casesTypes.isEagerSplit declName + /-- Returns `true` if `declName` is the name of a `match` equation or a `match` congruence equation. -/ @@ -758,7 +770,7 @@ structure EMatch.State where Inactive global theorems. As we internalize terms, we activate theorems as we find their symbols. Local theorem provided by users are added directly into `newThms`. -/ - thmMap : EMatchTheorems + thmMap : EMatchTheoremsArray /-- Goal modification time. -/ gmt : Nat := 0 /-- Active theorems that we have performed ematching at least once. -/ @@ -840,8 +852,6 @@ structure SplitArg where structure Split.State where /-- Number of splits performed to get to this goal. -/ num : Nat := 0 - /-- Inductive datatypes marked for case-splitting -/ - casesTypes : CasesTypes := {} /-- Case-split candidates. -/ candidates : List SplitInfo := [] /-- Case-splits that have been inserted at `candidates` at some point. -/ @@ -901,7 +911,7 @@ structure InjectiveInfo where /-- State for injective theorem support. -/ structure Injective.State where - thms : InjectiveTheorems + thms : InjectiveTheoremsArray fns : PHashMap ExprPtr InjectiveInfo := {} deriving Inhabited diff --git a/src/Lean/Meta/Tactic/LibrarySearch.lean b/src/Lean/Meta/Tactic/LibrarySearch.lean index 9821026303..c94f715b57 100644 --- a/src/Lean/Meta/Tactic/LibrarySearch.lean +++ b/src/Lean/Meta/Tactic/LibrarySearch.lean @@ -55,7 +55,7 @@ def grindDischarger (mvarId : MVarId) : MetaM (Option (List MVarId)) := do let [subgoal] ← mvarId.apply markerExpr | return none -- Solve the subgoal with grind - let params ← Grind.mkParams {} + let params ← Grind.mkDefaultParams {} let result ← Grind.main subgoal params if result.hasFailed then return none diff --git a/src/Lean/Meta/Tactic/Try/Collect.lean b/src/Lean/Meta/Tactic/Try/Collect.lean index 481e7d998d..852d8aaea7 100644 --- a/src/Lean/Meta/Tactic/Try/Collect.lean +++ b/src/Lean/Meta/Tactic/Try/Collect.lean @@ -126,7 +126,7 @@ def checkInductive (localDecl : LocalDecl) : M Unit := do let .const declName _ := type.getAppFn | return () let .inductInfo val ← getConstInfo declName | return () if (← isEligible declName) then - unless (← Grind.isSplit declName) do + unless (← Grind.isGlobalSplit declName) do modify fun s => { s with indCandidates := s.indCandidates.push { fvarId := localDecl.fvarId, val } } unsafe abbrev Cache := PtrSet Expr diff --git a/stage0/src/stdlib_flags.h b/stage0/src/stdlib_flags.h index 79a0e58edd..ad491b0de1 100644 --- a/stage0/src/stdlib_flags.h +++ b/stage0/src/stdlib_flags.h @@ -1,3 +1,4 @@ +// update me! #include "util/options.h" namespace lean { diff --git a/tests/lean/run/sharecommon_mpz.lean b/tests/lean/run/sharecommon_mpz.lean index dc692ed3c8..fe2290dfcd 100644 --- a/tests/lean/run/sharecommon_mpz.lean +++ b/tests/lean/run/sharecommon_mpz.lean @@ -3,7 +3,7 @@ import Lean open Lean Meta Tactic Grind def runGrind (x : GrindM α) : MetaM α := do - GrindM.run x (← mkParams {}) + GrindM.run x (← mkDefaultParams {}) @[noinline] def mkA (x : Nat) := x + 1 diff --git a/tests/pkg/user_attr/UserAttr/BlaAttr.lean b/tests/pkg/user_attr/UserAttr/BlaAttr.lean index 5373e6d96d..c67826b538 100644 --- a/tests/pkg/user_attr/UserAttr/BlaAttr.lean +++ b/tests/pkg/user_attr/UserAttr/BlaAttr.lean @@ -31,3 +31,5 @@ initialize registerBuiltinAttribute { logInfo m!"trace_add attribute added to {decl}" -- applicationTime := .afterCompilation } + +register_grind_attr my_grind diff --git a/tests/pkg/user_attr/UserAttr/Tst.lean b/tests/pkg/user_attr/UserAttr/Tst.lean index 9b3862f403..90d66b969a 100644 --- a/tests/pkg/user_attr/UserAttr/Tst.lean +++ b/tests/pkg/user_attr/UserAttr/Tst.lean @@ -155,3 +155,22 @@ termination_by n => n end end TraceAdd + +namespace GrindAttr + +opaque f : Nat → Nat +opaque g : Nat → Nat + +@[my_grind] theorem fax : f (f x) = f x := sorry + +@[my_grind =] theorem fax2 : f (f (f x)) = f x := by + fail_if_success grind + grind [my_grind] + +@[my_grind? .] theorem fg : g (f x) = x := sorry + +@[my_grind? =] theorem fax3 : f (f (f x)) = f x := sorry + +@[my_grind!? .] theorem fax4 : f (f (f x)) = f x := sorry + +end GrindAttr