diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index 342abc744e..63173f8f20 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -25,7 +25,8 @@ syntax grindUsr := &"usr " syntax grindCases := &"cases " syntax grindCasesEager := atomic(&"cases" &"eager ") syntax grindIntro := &"intro " -syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro +syntax grindExt := &"ext " +syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt syntax (name := grind) "grind" (grindMod)? : attr end Attr end Lean.Parser @@ -68,8 +69,10 @@ structure Config where failures : Nat := 1 /-- Maximum number of heartbeats (in thousands) the canonicalizer can spend per definitional equality test. -/ canonHeartbeats : Nat := 1000 - /-- If `ext` is `true`, `grind` uses extensionality theorems available in the environment. -/ + /-- If `ext` is `true`, `grind` uses extensionality theorems that have been marked with `[grind ext]`. -/ ext : Bool := true + /-- If `extAll` is `true`, `grind` uses any extensionality theorems available in the environment. -/ + extAll : Bool := false /-- If `funext` is `true`, `grind` creates new opportunities for applying function extensionality by case-splitting on equalities between lambda expressions. diff --git a/src/Lean/Elab/Tactic/Grind.lean b/src/Lean/Elab/Tactic/Grind.lean index d1d01f7d99..6ed9c8db11 100644 --- a/src/Lean/Elab/Tactic/Grind.lean +++ b/src/Lean/Elab/Tactic/Grind.lean @@ -89,6 +89,8 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic. params ← withRef p <| addEMatchTheorem params ctor .default else throwError "invalid use of `intro` modifier, `{declName}` is not an inductive predicate" + | .ext => + throwError "`[grind ext]` cannot be set using parameters" | .infer => if let some declName ← Grind.isCasesAttrCandidate? declName false then params := { params with casesTypes := params.casesTypes.insert declName false } diff --git a/src/Lean/Meta/Tactic/Ext.lean b/src/Lean/Meta/Tactic/Ext.lean index bb6ce28b0b..050b8d68e4 100644 --- a/src/Lean/Meta/Tactic/Ext.lean +++ b/src/Lean/Meta/Tactic/Ext.lean @@ -62,6 +62,15 @@ This is triggered by `attribute [-ext] name`. def ExtTheorems.eraseCore (d : ExtTheorems) (declName : Name) : ExtTheorems := { d with erased := d.erased.insert declName } +/-- Returns `true` if `d` contains theorem with name `declName`. -/ +def ExtTheorems.contains (d : ExtTheorems) (declName : Name) : Bool := + d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName + +/-- Returns `true` if `declName` is tagged with `[ext]` attribute. -/ +def isExtTheorem (declName : Name) : CoreM Bool := do + let extTheorems := extExtension.getState (← getEnv) + return extTheorems.contains declName + /-- Erases a name marked as a `ext` attribute. Check that it does in fact have the `ext` attribute by making sure it names a `ExtTheorem` @@ -69,7 +78,7 @@ found somewhere in the state's tree, and is not erased. -/ def ExtTheorems.erase [Monad m] [MonadError m] (d : ExtTheorems) (declName : Name) : m ExtTheorems := do - unless d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName do + unless d.contains declName do throwError "'{declName}' does not have [ext] attribute" return d.eraseCore declName diff --git a/src/Lean/Meta/Tactic/Grind/Attr.lean b/src/Lean/Meta/Tactic/Grind/Attr.lean index c2d9d65088..2388c40950 100644 --- a/src/Lean/Meta/Tactic/Grind/Attr.lean +++ b/src/Lean/Meta/Tactic/Grind/Attr.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Lean.Meta.Tactic.Grind.EMatchTheorem import Lean.Meta.Tactic.Grind.Cases +import Lean.Meta.Tactic.Grind.ExtAttr namespace Lean.Meta.Grind @@ -14,6 +15,7 @@ inductive AttrKind where | cases (eager : Bool) | intro | infer + | ext /-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do @@ -34,6 +36,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do | `(Parser.Attr.grindMod| cases) => return .cases false | `(Parser.Attr.grindMod| cases eager) => return .cases true | `(Parser.Attr.grindMod| intro) => return .intro + | `(Parser.Attr.grindMod| ext) => return .ext | _ => throwError "unexpected `grind` theorem kind: `{stx}`" /-- Return theorem kind for `stx` of the form `(Attr.grindMod)?` -/ @@ -78,6 +81,7 @@ builtin_initialize addEMatchAttr ctor attrKind .default else throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate" + | .ext => addExtAttr declName attrKind | .infer => if let some declName ← isCasesAttrCandidate? declName false then addCasesAttr declName false attrKind @@ -91,6 +95,8 @@ builtin_initialize erase := fun declName => MetaM.run' do if (← isCasesAttrCandidate declName false) then eraseCasesAttr declName + else if (← isExtTheorem declName) then + eraseExtAttr declName else eraseEMatchAttr declName } diff --git a/src/Lean/Meta/Tactic/Grind/ExtAttr.lean b/src/Lean/Meta/Tactic/Grind/ExtAttr.lean new file mode 100644 index 0000000000..7d763946bd --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/ExtAttr.lean @@ -0,0 +1,43 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Ext + +namespace Lean.Meta.Grind +/-! Grind extensionality attribute to mark which `[ext]` theorems should be used. -/ + +/-- Extensionality theorems that can be used by `grind` -/ +abbrev ExtTheorems := PHashSet Name + +builtin_initialize extTheoremsExt : SimpleScopedEnvExtension Name ExtTheorems ← + registerSimpleScopedEnvExtension { + initial := {} + addEntry := fun s declName => s.insert declName + } + +def validateExtAttr (declName : Name) : CoreM Unit := do + unless (← Ext.isExtTheorem declName) do + throwError "invalid `[grind ext]`, `{declName}` is tagged with `[ext]`" + +def addExtAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do + validateExtAttr declName + extTheoremsExt.add declName attrKind + +private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do + if s.contains declName then + return s.erase declName + else + throwError "`{declName}` is not marked with the `[grind ext]` attribute" + +def eraseExtAttr (declName : Name) : CoreM Unit := do + let s := extTheoremsExt.getState (← getEnv) + let s ← eraseDecl s declName + modifyEnv fun env => extTheoremsExt.modifyState env fun _ => s + +def isExtTheorem (declName : Name) : CoreM Bool := do + return extTheoremsExt.getState (← getEnv) |>.contains declName + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 056c1650cd..b59a281746 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -17,6 +17,7 @@ import Lean.Meta.Tactic.Util import Lean.Meta.Tactic.Ext import Lean.Meta.Tactic.Grind.ENodeKey import Lean.Meta.Tactic.Grind.Attr +import Lean.Meta.Tactic.Grind.ExtAttr import Lean.Meta.Tactic.Grind.Cases import Lean.Meta.Tactic.Grind.Arith.Types import Lean.Meta.Tactic.Grind.EMatchTheorem @@ -572,7 +573,7 @@ structure Goal where -/ appMap : PHashMap HeadIndex (List Expr) := {} /-- Equations and propositions to be processed. -/ - newFacts : Array NewFact := #[] + newFacts : Array NewFact := #[] /-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/ inconsistent : Bool := false /-- Next unique index for creating ENodes -/ @@ -580,17 +581,17 @@ structure Goal where /-- new facts to be preprocessed and then asserted. -/ newRawFacts : Std.Queue NewRawFact := ∅ /-- Asserted facts -/ - facts : PArray Expr := {} + facts : PArray Expr := {} /-- Cached extensionality theorems for types. -/ - extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {} + extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {} /-- State of the E-matching module. -/ - ematch : EMatch.State + ematch : EMatch.State /-- State of the case-splitting module. -/ - split : Split.State := {} + split : Split.State := {} /-- State of arithmetic procedures. -/ - arith : Arith.State := {} + arith : Arith.State := {} /-- State of the clean name generator. -/ - clean : Clean.State := {} + clean : Clean.State := {} deriving Inhabited def Goal.admit (goal : Goal) : MetaM Unit := @@ -1260,11 +1261,15 @@ Returns extensionality theorems for the given type if available. If `Config.ext` is `false`, the result is `#[]`. -/ def getExtTheorems (type : Expr) : GoalM (Array Ext.ExtTheorem) := do - unless (← getConfig).ext do return #[] + unless (← getConfig).ext || (← getConfig).extAll do return #[] if let some thms := (← get).extThms.find? { expr := type } then return thms else let thms ← Ext.getExtTheorems type + let thms ← if (← getConfig).extAll then + pure thms + else + thms.filterM fun thm => isExtTheorem thm.declName modify fun s => { s with extThms := s.extThms.insert { expr := type } thms } return thms diff --git a/tests/lean/run/grind_bintree.lean b/tests/lean/run/grind_bintree.lean index 5a68261a2a..00d447362e 100644 --- a/tests/lean/run/grind_bintree.lean +++ b/tests/lean/run/grind_bintree.lean @@ -62,6 +62,8 @@ where : toListTR.go t acc = t.toList ++ acc := by induction t generalizing acc <;> grind [toListTR.go, toList] +attribute [grind ext] funext -- TODO: remove after update-stage0 + @[csimp] theorem Tree.toList_eq_toListTR_csimp : @Tree.toList = @Tree.toListTR := by grind [toList_eq_toListTR] diff --git a/tests/lean/run/grind_cat.lean b/tests/lean/run/grind_cat.lean index 3b4dbf3fc8..9e1422f5c5 100644 --- a/tests/lean/run/grind_cat.lean +++ b/tests/lean/run/grind_cat.lean @@ -67,6 +67,9 @@ structure NatTrans [Category.{v₁, u₁} C] [Category.{v₂, u₂} D] (F G : Fu /-- The naturality square for a given morphism. -/ naturality : ∀ ⦃X Y : C⦄ (f : X ⟶ Y), F.map f ≫ app Y = app X ≫ G.map f := by grind +attribute [grind ext] NatTrans.ext -- TODO: remove after builtin extensionality for structures +attribute [grind ext] funext -- TODO: remove after update-stage0 + attribute [simp, grind =] NatTrans.naturality namespace NatTrans @@ -103,6 +106,8 @@ namespace NatTrans @[ext] theorem ext' {α β : F ⟶ G} (w : α.app = β.app) : α = β := NatTrans.ext w +attribute [grind ext] ext' + @[simp, grind =] theorem id_app (F : Functor C D) (X : C) : (𝟙 F : F ⟶ F).app X = 𝟙 (F.obj X) := rfl @@ -168,7 +173,7 @@ variable {C : Type u} [Category.{v} C] {X Y Z : C} namespace Iso -@[ext] +@[ext, grind ext] theorem ext ⦃α β : X ≅ Y⦄ (w : α.hom = β.hom) : α = β := suffices α.inv = β.inv by grind [Iso] calc diff --git a/tests/lean/run/grind_cat2.lean b/tests/lean/run/grind_cat2.lean index 423052db71..29847354e2 100644 --- a/tests/lean/run/grind_cat2.lean +++ b/tests/lean/run/grind_cat2.lean @@ -68,6 +68,9 @@ structure NatTrans [Category.{v₁, u₁} C] [Category.{v₂, u₂} D] (F G : Fu /-- The naturality square for a given morphism. -/ naturality : ∀ ⦃X Y : C⦄ (f : X ⟶ Y), F.map f ≫ app Y = app X ≫ G.map f := by grind +attribute [grind ext] funext -- TODO: remove +attribute [grind ext] NatTrans.ext + attribute [simp, grind =] NatTrans.naturality namespace NatTrans diff --git a/tests/lean/run/grind_funext.lean b/tests/lean/run/grind_funext.lean index d0a8a0f4fa..bb8de6af85 100644 --- a/tests/lean/run/grind_funext.lean +++ b/tests/lean/run/grind_funext.lean @@ -1,5 +1,6 @@ set_option grind.warning false +attribute [grind ext] funext -- TODO: remove example (f : (Nat → Nat) → Nat → Nat → Nat) : a = b → f (fun x => a + x) 1 b = f (fun x => b + x) 1 a := by grind diff --git a/tests/lean/run/grind_ite.lean b/tests/lean/run/grind_ite.lean index cd81a6aede..2bb073af47 100644 --- a/tests/lean/run/grind_ite.lean +++ b/tests/lean/run/grind_ite.lean @@ -117,7 +117,7 @@ set_option grind.warning false -- We first set up some convenient macros for dealing with subtypes using `grind`. /-- Construct a term of a subtype, using `grind` to discharge the condition. -/ -macro "g⟨" a:term "⟩" : term => `(⟨$a, by grind (gen := 9) (splits := 9)⟩) +macro "g⟨" a:term "⟩" : term => `(⟨$a, by grind (gen := 8) (splits := 9)⟩) /-- Replace a term of a subtype with a term of a different subtype, using the same data, and using `grind` to discharge the new condition (with access to the old condition). @@ -159,6 +159,8 @@ we are allowed to increase the size of the branches by one, and still be smaller | var _ => 1 | .ite i t e => 2 * normSize i + max (normSize t) (normSize e) + 1 +attribute [grind ext] funext -- TODO remove +set_option profiler true /-- Normalizes the expression at the same time as making the variable assignments to literal booleans given by `assign`. diff --git a/tests/lean/run/grind_list.lean b/tests/lean/run/grind_list.lean index 3d0e9d4ab2..d10ecfba81 100644 --- a/tests/lean/run/grind_list.lean +++ b/tests/lean/run/grind_list.lean @@ -1,4 +1,5 @@ reset_grind_attrs% +set_option grind.warning false namespace List @@ -18,6 +19,8 @@ theorem getElem!_of_getElem?' [Inhabited α] : ∀ {l : List α} {i : Nat}, l[i]? = some b → l[i]! = b := by grind +attribute [grind ext] List.ext_getElem? + attribute [local grind =] Option.map_some Option.map_none in attribute [local grind =] getElem?_map in attribute [local grind =] getElem?_replicate in diff --git a/tests/lean/run/grind_t1.lean b/tests/lean/run/grind_t1.lean index 03c178b903..98332bb6dd 100644 --- a/tests/lean/run/grind_t1.lean +++ b/tests/lean/run/grind_t1.lean @@ -317,6 +317,8 @@ example {α} (f : α → Type) (a : α) (h : ∀ x, Nonempty (f x)) : Nonempty ( example {α β} (f : α → β) (a : α) : ∃ a', f a' = f a := by grind +attribute [grind ext] List.ext_getElem? + open List in example : (replicate n a).map f = replicate n (f a) := by grind +splitIndPred only [Option.map_some, Option.map_none, getElem?_map, getElem?_replicate] @@ -339,6 +341,8 @@ example : (replicate n a).map f = replicate n (f a) := by a : Nat b : Bool +attribute [grind ext] S.ext + example (x y : S) : x.a = y.a → y.b = x.b → x = y := by grind diff --git a/tests/lean/run/grind_trace.lean b/tests/lean/run/grind_trace.lean index 4c939adb79..c3caadda37 100644 --- a/tests/lean/run/grind_trace.lean +++ b/tests/lean/run/grind_trace.lean @@ -32,6 +32,7 @@ info: Try this: grind only [= List.length_cons] example : 0 < (x :: t).length := by grind? +attribute [grind ext] List.ext_getElem? /-- info: Try this: grind only [= Option.map_some, = Option.map_none, = List.getElem?_replicate, = List.getElem?_eq_some_iff, = List.getElem?_map, = List.getElem_replicate, = List.getElem?_eq_none, = List.length_replicate, →