feat: add [grind ext] attribute (#7949)

This PR adds the attribute `[grind ext]`. It is used to select which
`[ext]` theorems should be used by `grind`. The option `grind +extAll`
instructs `grind` to use all `[ext]` theorems available in the
environment.
After update stage0, we need to add the builtin `[grind ext]`
annotations to key theorems such as `funext`.
This commit is contained in:
Leonardo de Moura 2025-04-13 15:08:36 -07:00 committed by GitHub
parent 2337b95676
commit cd5b495573
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 102 additions and 13 deletions

View file

@ -25,7 +25,8 @@ syntax grindUsr := &"usr "
syntax grindCases := &"cases "
syntax grindCasesEager := atomic(&"cases" &"eager ")
syntax grindIntro := &"intro "
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro
syntax grindExt := &"ext "
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt
syntax (name := grind) "grind" (grindMod)? : attr
end Attr
end Lean.Parser
@ -68,8 +69,10 @@ structure Config where
failures : Nat := 1
/-- Maximum number of heartbeats (in thousands) the canonicalizer can spend per definitional equality test. -/
canonHeartbeats : Nat := 1000
/-- If `ext` is `true`, `grind` uses extensionality theorems available in the environment. -/
/-- If `ext` is `true`, `grind` uses extensionality theorems that have been marked with `[grind ext]`. -/
ext : Bool := true
/-- If `extAll` is `true`, `grind` uses any extensionality theorems available in the environment. -/
extAll : Bool := false
/--
If `funext` is `true`, `grind` creates new opportunities for applying function extensionality by case-splitting
on equalities between lambda expressions.

View file

@ -89,6 +89,8 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.
params ← withRef p <| addEMatchTheorem params ctor .default
else
throwError "invalid use of `intro` modifier, `{declName}` is not an inductive predicate"
| .ext =>
throwError "`[grind ext]` cannot be set using parameters"
| .infer =>
if let some declName ← Grind.isCasesAttrCandidate? declName false then
params := { params with casesTypes := params.casesTypes.insert declName false }

View file

@ -62,6 +62,15 @@ This is triggered by `attribute [-ext] name`.
def ExtTheorems.eraseCore (d : ExtTheorems) (declName : Name) : ExtTheorems :=
{ d with erased := d.erased.insert declName }
/-- Returns `true` if `d` contains theorem with name `declName`. -/
def ExtTheorems.contains (d : ExtTheorems) (declName : Name) : Bool :=
d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName
/-- Returns `true` if `declName` is tagged with `[ext]` attribute. -/
def isExtTheorem (declName : Name) : CoreM Bool := do
let extTheorems := extExtension.getState (← getEnv)
return extTheorems.contains declName
/--
Erases a name marked as a `ext` attribute.
Check that it does in fact have the `ext` attribute by making sure it names a `ExtTheorem`
@ -69,7 +78,7 @@ found somewhere in the state's tree, and is not erased.
-/
def ExtTheorems.erase [Monad m] [MonadError m] (d : ExtTheorems) (declName : Name) :
m ExtTheorems := do
unless d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName do
unless d.contains declName do
throwError "'{declName}' does not have [ext] attribute"
return d.eraseCore declName

View file

@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.ExtAttr
namespace Lean.Meta.Grind
@ -14,6 +15,7 @@ inductive AttrKind where
| cases (eager : Bool)
| intro
| infer
| ext
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
@ -34,6 +36,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
| `(Parser.Attr.grindMod| cases) => return .cases false
| `(Parser.Attr.grindMod| cases eager) => return .cases true
| `(Parser.Attr.grindMod| intro) => return .intro
| `(Parser.Attr.grindMod| ext) => return .ext
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
/-- Return theorem kind for `stx` of the form `(Attr.grindMod)?` -/
@ -78,6 +81,7 @@ builtin_initialize
addEMatchAttr ctor attrKind .default
else
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
| .ext => addExtAttr declName attrKind
| .infer =>
if let some declName ← isCasesAttrCandidate? declName false then
addCasesAttr declName false attrKind
@ -91,6 +95,8 @@ builtin_initialize
erase := fun declName => MetaM.run' do
if (← isCasesAttrCandidate declName false) then
eraseCasesAttr declName
else if (← isExtTheorem declName) then
eraseExtAttr declName
else
eraseEMatchAttr declName
}

View file

@ -0,0 +1,43 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Ext
namespace Lean.Meta.Grind
/-! Grind extensionality attribute to mark which `[ext]` theorems should be used. -/
/-- Extensionality theorems that can be used by `grind` -/
abbrev ExtTheorems := PHashSet Name
builtin_initialize extTheoremsExt : SimpleScopedEnvExtension Name ExtTheorems ←
registerSimpleScopedEnvExtension {
initial := {}
addEntry := fun s declName => s.insert declName
}
def validateExtAttr (declName : Name) : CoreM Unit := do
unless (← Ext.isExtTheorem declName) do
throwError "invalid `[grind ext]`, `{declName}` is tagged with `[ext]`"
def addExtAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
validateExtAttr declName
extTheoremsExt.add declName attrKind
private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do
if s.contains declName then
return s.erase declName
else
throwError "`{declName}` is not marked with the `[grind ext]` attribute"
def eraseExtAttr (declName : Name) : CoreM Unit := do
let s := extTheoremsExt.getState (← getEnv)
let s ← eraseDecl s declName
modifyEnv fun env => extTheoremsExt.modifyState env fun _ => s
def isExtTheorem (declName : Name) : CoreM Bool := do
return extTheoremsExt.getState (← getEnv) |>.contains declName
end Lean.Meta.Grind

View file

@ -17,6 +17,7 @@ import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Ext
import Lean.Meta.Tactic.Grind.ENodeKey
import Lean.Meta.Tactic.Grind.Attr
import Lean.Meta.Tactic.Grind.ExtAttr
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Arith.Types
import Lean.Meta.Tactic.Grind.EMatchTheorem
@ -572,7 +573,7 @@ structure Goal where
-/
appMap : PHashMap HeadIndex (List Expr) := {}
/-- Equations and propositions to be processed. -/
newFacts : Array NewFact := #[]
newFacts : Array NewFact := #[]
/-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/
inconsistent : Bool := false
/-- Next unique index for creating ENodes -/
@ -580,17 +581,17 @@ structure Goal where
/-- new facts to be preprocessed and then asserted. -/
newRawFacts : Std.Queue NewRawFact := ∅
/-- Asserted facts -/
facts : PArray Expr := {}
facts : PArray Expr := {}
/-- Cached extensionality theorems for types. -/
extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {}
extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {}
/-- State of the E-matching module. -/
ematch : EMatch.State
ematch : EMatch.State
/-- State of the case-splitting module. -/
split : Split.State := {}
split : Split.State := {}
/-- State of arithmetic procedures. -/
arith : Arith.State := {}
arith : Arith.State := {}
/-- State of the clean name generator. -/
clean : Clean.State := {}
clean : Clean.State := {}
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=
@ -1260,11 +1261,15 @@ Returns extensionality theorems for the given type if available.
If `Config.ext` is `false`, the result is `#[]`.
-/
def getExtTheorems (type : Expr) : GoalM (Array Ext.ExtTheorem) := do
unless (← getConfig).ext do return #[]
unless (← getConfig).ext || (← getConfig).extAll do return #[]
if let some thms := (← get).extThms.find? { expr := type } then
return thms
else
let thms ← Ext.getExtTheorems type
let thms ← if (← getConfig).extAll then
pure thms
else
thms.filterM fun thm => isExtTheorem thm.declName
modify fun s => { s with extThms := s.extThms.insert { expr := type } thms }
return thms

View file

@ -62,6 +62,8 @@ where
: toListTR.go t acc = t.toList ++ acc := by
induction t generalizing acc <;> grind [toListTR.go, toList]
attribute [grind ext] funext -- TODO: remove after update-stage0
@[csimp] theorem Tree.toList_eq_toListTR_csimp
: @Tree.toList = @Tree.toListTR := by
grind [toList_eq_toListTR]

View file

@ -67,6 +67,9 @@ structure NatTrans [Category.{v₁, u₁} C] [Category.{v₂, u₂} D] (F G : Fu
/-- The naturality square for a given morphism. -/
naturality : ∀ ⦃X Y : C⦄ (f : X ⟶ Y), F.map f ≫ app Y = app X ≫ G.map f := by grind
attribute [grind ext] NatTrans.ext -- TODO: remove after builtin extensionality for structures
attribute [grind ext] funext -- TODO: remove after update-stage0
attribute [simp, grind =] NatTrans.naturality
namespace NatTrans
@ -103,6 +106,8 @@ namespace NatTrans
@[ext]
theorem ext' {α β : F ⟶ G} (w : α.app = β.app) : α = β := NatTrans.ext w
attribute [grind ext] ext'
@[simp, grind =]
theorem id_app (F : Functor C D) (X : C) : (𝟙 F : F ⟶ F).app X = 𝟙 (F.obj X) := rfl
@ -168,7 +173,7 @@ variable {C : Type u} [Category.{v} C] {X Y Z : C}
namespace Iso
@[ext]
@[ext, grind ext]
theorem ext ⦃α β : X ≅ Y⦄ (w : α.hom = β.hom) : α = β :=
suffices α.inv = β.inv by grind [Iso]
calc

View file

@ -68,6 +68,9 @@ structure NatTrans [Category.{v₁, u₁} C] [Category.{v₂, u₂} D] (F G : Fu
/-- The naturality square for a given morphism. -/
naturality : ∀ ⦃X Y : C⦄ (f : X ⟶ Y), F.map f ≫ app Y = app X ≫ G.map f := by grind
attribute [grind ext] funext -- TODO: remove
attribute [grind ext] NatTrans.ext
attribute [simp, grind =] NatTrans.naturality
namespace NatTrans

View file

@ -1,5 +1,6 @@
set_option grind.warning false
attribute [grind ext] funext -- TODO: remove
example (f : (Nat → Nat) → Nat → Nat → Nat) : a = b → f (fun x => a + x) 1 b = f (fun x => b + x) 1 a := by
grind

View file

@ -117,7 +117,7 @@ set_option grind.warning false
-- We first set up some convenient macros for dealing with subtypes using `grind`.
/-- Construct a term of a subtype, using `grind` to discharge the condition. -/
macro "g⟨" a:term "⟩" : term => `(⟨$a, by grind (gen := 9) (splits := 9)⟩)
macro "g⟨" a:term "⟩" : term => `(⟨$a, by grind (gen := 8) (splits := 9)⟩)
/--
Replace a term of a subtype with a term of a different subtype, using the same data,
and using `grind` to discharge the new condition (with access to the old condition).
@ -159,6 +159,8 @@ we are allowed to increase the size of the branches by one, and still be smaller
| var _ => 1
| .ite i t e => 2 * normSize i + max (normSize t) (normSize e) + 1
attribute [grind ext] funext -- TODO remove
set_option profiler true
/--
Normalizes the expression at the same time as
making the variable assignments to literal booleans given by `assign`.

View file

@ -1,4 +1,5 @@
reset_grind_attrs%
set_option grind.warning false
namespace List
@ -18,6 +19,8 @@ theorem getElem!_of_getElem?' [Inhabited α] :
∀ {l : List α} {i : Nat}, l[i]? = some b → l[i]! = b := by
grind
attribute [grind ext] List.ext_getElem?
attribute [local grind =] Option.map_some Option.map_none in
attribute [local grind =] getElem?_map in
attribute [local grind =] getElem?_replicate in

View file

@ -317,6 +317,8 @@ example {α} (f : α → Type) (a : α) (h : ∀ x, Nonempty (f x)) : Nonempty (
example {α β} (f : α → β) (a : α) : ∃ a', f a' = f a := by
grind
attribute [grind ext] List.ext_getElem?
open List in
example : (replicate n a).map f = replicate n (f a) := by
grind +splitIndPred only [Option.map_some, Option.map_none, getElem?_map, getElem?_replicate]
@ -339,6 +341,8 @@ example : (replicate n a).map f = replicate n (f a) := by
a : Nat
b : Bool
attribute [grind ext] S.ext
example (x y : S) : x.a = y.a → y.b = x.b → x = y := by
grind

View file

@ -32,6 +32,7 @@ info: Try this: grind only [= List.length_cons]
example : 0 < (x :: t).length := by
grind?
attribute [grind ext] List.ext_getElem?
/--
info: Try this: grind only [= Option.map_some, = Option.map_none, = List.getElem?_replicate, = List.getElem?_eq_some_iff, =
List.getElem?_map, = List.getElem_replicate, = List.getElem?_eq_none, = List.length_replicate, →