fix: [grind inj] attribute (#10482)

This PR fixes symbol collection for the `@[grind inj]` attribute.
This commit is contained in:
Leonardo de Moura 2025-09-20 21:14:17 -07:00 committed by GitHub
parent 5f68c1662d
commit 42be7bb5c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 42 additions and 5 deletions

View file

@ -5,7 +5,8 @@ Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.Theorems
public import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.FunInfo
public section
namespace Lean.Meta.Grind
@ -37,16 +38,33 @@ private builtin_initialize injectiveTheoremsExt : SimpleScopedEnvExtension Injec
initial := {}
}
private def getSymbols (proof : Expr) : MetaM (List HeadIndex) := do
private partial def getSymbols (proof : Expr) : MetaM (List HeadIndex) := do
let type ← inferType proof
forallTelescope type fun _ type => do
unless type.isAppOfArity ``Function.Injective 3 do
throwError "invalid `[grind inj]` theorem, resulting type is not of the form `Function.Injective <fun>`{indentExpr type}"
let f := type.appArg!
let cs : NameSet := f.foldConsts (init := {}) fun declName s => s.insert declName
let f := type.appArg!.eta
let cs ← collectFnNames f
if cs.isEmpty then
throwError "invalid `[grind inj]` theorem, injective function must use at least one constant symbol{indentExpr f}"
throwError "invalid `[grind inj]` theorem, injective function must use at least one constant function symbol{indentExpr f}"
return cs.toList.map (.const ·)
where
collectFnNames (f : Expr) : MetaM NameSet := do
if let .const declName _ := f then
return { declName }
else
Prod.snd <$> (go f |>.run {})
go (e : Expr) : StateRefT NameSet MetaM Unit := do
if e.isApp then e.withApp fun f args => do
if let .const declName _ := f then
modify (·.insert declName)
let kinds ← NormalizePattern.getPatternArgKinds f args.size
for h : i in *...args.size do
let arg := args[i]
let kind := kinds[i]?.getD .relevant
if kind matches .relevant | .typeFormer then
go arg
private def symbolsToNames (s : List HeadIndex) : List Name :=
s.map fun

View file

@ -45,3 +45,22 @@ error: invalid `[grind inj]` theorem, resulting type is not of the form `Functio
#guard_msgs in
@[grind inj] theorem succ_inj' : succ x = succ y → x = y := by
grind [succ]
/-- trace: [grind.inj] mul_2_inj: [HMul.hMul, OfNat.ofNat] -/
#guard_msgs in
set_option trace.grind.inj true in
@[grind inj] theorem mul_2_inj : Function.Injective (2 * ·) := by
grind [Function.Injective]
def Array.IsId (as : Array Nat) : Prop :=
∀ i : Fin as.size, as[i] = i
/-- trace: [grind.inj] array_inj: [Array, GetElem?.getElem?, Fin, Array.size] -/
#guard_msgs in
set_option trace.grind.inj true in
@[grind inj] theorem array_inj {as : Array Nat} (h : as.IsId) : Function.Injective (as[·]? : Fin as.size → Option Nat) := by
intro a b; simp
have ha := h a
have hb := h b
simp at ha hb
grind