feat: instantiate tactic parameters (#10746)

This PR implements parameters for the `instantiate` tactic in the
`grind` interactive mode. Users can now select both global and local
theorems. Local theorems are selected using anchors. It also adds the
`show_thms` tactic for displaying local theorems. Example:

```lean
example (as bs cs : Array α) (v₁ v₂ : α)
        (i₁ i₂ j : Nat)
        (h₁ : i₁ < as.size)
        (h₂ : bs = as.set i₁ v₁)
        (h₃ : i₂ < bs.size)
        (h₃ : cs = bs.set i₂ v₂)
        (h₄ : i₁ ≠ j ∧ i₂ ≠ j)
        (h₅ : j < cs.size)
        (h₆ : j < as.size)
        : cs[j] = as[j] := by
  grind =>
    instantiate = Array.getElem_set
    instantiate Array.getElem_set
```
This commit is contained in:
Leonardo de Moura 2025-10-11 14:35:21 -07:00 committed by GitHub
parent 0dc862e3ed
commit 4f7d3bb692
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 235 additions and 30 deletions

View file

@ -6,8 +6,18 @@ Authors: Leonardo de Moura
module
prelude
public import Init.Tactics
public import Init.Grind.Attr
public section
namespace Lean.Parser.Tactic.Grind
namespace Lean.Parser.Tactic
syntax grindLemma := ppGroup((Attr.grindMod ppSpace)? ident)
/--
The `!` modifier instructs `grind` to consider only minimal indexable subexpressions
when selecting patterns.
-/
syntax grindLemmaMin := ppGroup("!" (Attr.grindMod ppSpace)? ident)
namespace Grind
/-- `grind` is the syntax category for a "grind interactive tactic".
A `grind` tactic is a program which receives a `grind` goal. -/
@ -35,8 +45,11 @@ syntax (name := linarith) "linarith" : grind
/-- The `sorry` tactic is a temporary placeholder for an incomplete tactic proof. -/
syntax (name := «sorry») "sorry" : grind
syntax anchor := "#" noWs hexnum
syntax thm := anchor <|> grindLemma <|> grindLemmaMin
/-- Instantiates theorems using E-matching. -/
syntax (name := instantiate) "instantiate" : grind
syntax (name := instantiate) "instantiate" (colGt thm),* : grind
declare_syntax_cat show_filter (behavior := both)
@ -58,21 +71,23 @@ syntax showFilter := (colGt show_filter)?
-- **Note**: Should we rename the following tactics to `trace_`?
/-- Shows asserted facts. -/
syntax (name := showAsserted) "show_asserted " showFilter : grind
syntax (name := showAsserted) "show_asserted" ppSpace showFilter : grind
/-- Shows propositions known to be `True`. -/
syntax (name := showTrue) "show_true " showFilter : grind
syntax (name := showTrue) "show_true" ppSpace showFilter : grind
/-- Shows propositions known to be `False`. -/
syntax (name := showFalse) "show_false " showFilter : grind
syntax (name := showFalse) "show_false" ppSpace showFilter : grind
/-- Shows equivalence classes of terms. -/
syntax (name := showEqcs) "show_eqcs " showFilter : grind
syntax (name := showEqcs) "show_eqcs" ppSpace showFilter : grind
/-- Show case-split candidates. -/
syntax (name := showSplits) "show_splits " showFilter : grind
syntax (name := showSplits) "show_splits" ppSpace showFilter : grind
/-- Show `grind` state. -/
syntax (name := «showState») "show_state " showFilter : grind
syntax (name := «showState») "show_state" ppSpace showFilter : grind
/-- Show active local theorems and their anchors for heuristic instantiation. -/
syntax (name := showThms) "show_thms" : grind
declare_syntax_cat grind_ref (behavior := both)
syntax:max "#" noWs hexnum : grind_ref
syntax:max anchor : grind_ref
syntax term : grind_ref
syntax (name := cases) "cases " grind_ref (" with " (colGt ident)+)? : grind
@ -143,4 +158,5 @@ macro "admit" : grind => `(grind| sorry)
/-- `fail msg` is a tactic that always fails, and produces an error using the given message. -/
syntax (name := fail) "fail" (ppSpace str)? : grind
end Lean.Parser.Tactic.Grind
end Grind
end Lean.Parser.Tactic

View file

@ -6,7 +6,6 @@ Authors: Leonardo de Moura
module
prelude
public import Init.Core
public import Init.Grind.Attr
public import Init.Grind.Interactive
public section
namespace Lean.Grind
@ -209,14 +208,11 @@ namespace Lean.Parser.Tactic
/-!
`grind` tactic and related tactics.
-/
syntax grindErase := "-" ident
syntax grindLemma := ppGroup((Attr.grindMod ppSpace)? ident)
/--
The `!` modifier instructs `grind` to consider only minimal indexable subexpressions
when selecting patterns.
-/
syntax grindLemmaMin := ppGroup("!" (Attr.grindMod ppSpace)? ident)
syntax grindParam := grindErase <|> grindLemma <|> grindLemmaMin
open Parser.Tactic.Grind

View file

@ -102,8 +102,8 @@ def evalCheck (tacticName : Name) (k : GoalM Bool)
@[builtin_grind_tactic ac] def evalAC : GrindTactic := fun _ => do
evalCheck `ac AC.check AC.pp?
@[builtin_grind_tactic instantiate] def evalInstantiate : GrindTactic := fun _ => do
let progress ← liftGoalM <| ematch
def ematchThms (thms : Array EMatchTheorem) : GrindTacticM Unit := do
let progress ← liftGoalM <| if thms.isEmpty then ematch else ematchTheorems thms
unless progress do
throwError "`instantiate` tactic failed to instantiate new facts, use `show_patterns` to see active theorems and their patterns."
let goal ← getMainGoal
@ -112,14 +112,110 @@ def evalCheck (tacticName : Name) (k : GoalM Bool)
getGoal
replaceMainGoal [goal]
def elabAnchor (anchor : TSyntax `hexnum) : CoreM (Nat × UInt64) := do
let numDigits := anchor.getHexNumSize
let val := anchor.getHexNumVal
if val >= UInt64.size then
throwError "invalid anchor, value is too big"
let val := val.toUInt64
return (numDigits, val)
@[builtin_grind_tactic instantiate] def evalInstantiate : GrindTactic := fun stx => withMainContext do
match stx with
| `(grind| instantiate $[$thmRefs:thm],*) =>
let mut thms := #[]
for thmRef in thmRefs do
match thmRef with
| `(Parser.Tactic.Grind.thm| #$anchor:hexnum) => thms := thms ++ (← withRef thmRef <| elabLocalEMatchTheorem anchor)
| `(Parser.Tactic.Grind.thm| $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id false)
| `(Parser.Tactic.Grind.thm| ! $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id true)
| _ => throwErrorAt thmRef "unexpected theorem reference"
ematchThms thms
| _ => throwUnsupportedSyntax
where
collectThms (numDigits : Nat) (anchorPrefix : UInt64) (thms : PArray EMatchTheorem) : StateT (Array EMatchTheorem) GrindM Unit := do
for thm in thms do
-- **Note**: `anchors` are cached using pointer addresses, if this is a performance issue, we should
-- cache the theorem types.
let type ← inferType thm.proof
let anchor ← getAnchor type
if isAnchorPrefix numDigits anchorPrefix anchor then
modify (·.push thm)
elabLocalEMatchTheorem (anchor : TSyntax `hexnum) : GrindTacticM (Array EMatchTheorem) := do
let (numDigits, anchorPrefix) ← elabAnchor anchor
let goal ← getMainGoal
let thms ← liftGrindM do StateT.run' (s := #[]) do
collectThms numDigits anchorPrefix goal.ematch.thms
collectThms numDigits anchorPrefix goal.ematch.newThms
get
if thms.isEmpty then
throwError "no local theorems"
return thms
ensureNoMinIndexable (minIndexable : Bool) : MetaM Unit := do
if minIndexable then
throwError "redundant modifier `!` in `grind` parameter"
elabEMatchTheorem (declName : Name) (kind : Grind.EMatchTheoremKind) (minIndexable : Bool) : GrindTacticM (Array EMatchTheorem) := do
let params := (← read).params
let info ← getAsyncConstInfo declName
match info.kind with
| .thm | .axiom | .ctor =>
match kind with
| .eqBoth gen =>
ensureNoMinIndexable minIndexable
let thm₁ ← Grind.mkEMatchTheoremForDecl declName (.eqLhs gen) params.symPrios
let thm₂ ← Grind.mkEMatchTheoremForDecl declName (.eqRhs gen) params.symPrios
return #[thm₁, thm₂]
| _ =>
if kind matches .eqLhs _ | .eqRhs _ then
ensureNoMinIndexable minIndexable
let thm ← Grind.mkEMatchTheoremForDecl declName kind params.symPrios (minIndexable := minIndexable)
return #[thm]
| .defn =>
if (← isReducible declName) then
throwError "`{.ofConstName declName}` is a reducible definition, `grind` automatically unfolds them"
if !kind.isEqLhs && !kind.isDefault then
throwError "invalid `grind` parameter, `{.ofConstName declName}` is a definition, the only acceptable (and redundant) modifier is '='"
ensureNoMinIndexable minIndexable
let some thms ← Grind.mkEMatchEqTheoremsForDef? declName
| throwError "failed to generate equation theorems for `{.ofConstName declName}`"
return thms
| _ =>
throwError "invalid `grind` parameter, `{.ofConstName declName}` is not a theorem, definition, or inductive type"
elabThm
(mod? : Option (TSyntax `Lean.Parser.Attr.grindMod))
(id : TSyntax `ident)
(minIndexable : Bool) : GrindTacticM (Array EMatchTheorem) := do
let declName ← realizeGlobalConstNoOverloadWithInfo id
let kind ← if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer
match kind with
| .ematch .user =>
ensureNoMinIndexable minIndexable
let s ← Grind.getEMatchTheorems
let thms := s.find (.decl declName)
let thms := thms.filter fun thm => thm.kind == .user
if thms.isEmpty then
throwError "invalid use of `usr` modifier, `{.ofConstName declName}` does not have patterns specified with the command `grind_pattern`"
return thms.toArray
| .ematch kind =>
elabEMatchTheorem declName kind minIndexable
| .infer =>
let goal ← getMainGoal
let thms := goal.ematch.thmMap.find (.decl declName)
if thms.isEmpty then
elabEMatchTheorem declName (.default false) minIndexable
else
return thms.toArray
| .cases _ | .intro | .inj | .ext | .symbol _ =>
throwError "invalid modifier"
@[builtin_grind_tactic cases] def evalCases : GrindTactic := fun stx => do
match stx with
| `(grind| cases #$anchor:hexnum) =>
let numDigits := anchor.getHexNumSize
let val := anchor.getHexNumVal
if val >= UInt64.size then
throwError "invalid anchor, value is too big"
let val := val.toUInt64
let (numDigits, val) ← elabAnchor anchor
let goal ← getMainGoal
let candidates := goal.split.candidates
let (goals, genNew) ← liftSearchM do

View file

@ -160,27 +160,27 @@ def pushIfSome (msgs : Array MessageData) (msg? : Option MessageData) : Array Me
let filter ← elabFilter filter?
let msgs := #[]
let msgs := pushIfSome msgs (← ppAsserted? filter (collapsed := true))
let msgs := pushIfSome msgs (← ppProps? filter true (collapsed := true))
let msgs := pushIfSome msgs (← ppProps? filter false (collapsed := true))
let msgs := pushIfSome msgs (← ppEqcs? filter (collapsed := true))
let msgs := pushIfSome msgs (← ppProps? filter true (collapsed := false))
let msgs := pushIfSome msgs (← ppProps? filter false (collapsed := false))
let msgs := pushIfSome msgs (← ppEqcs? filter (collapsed := false))
logInfo <| MessageData.trace { cls := `grind, collapsed := false } "Grind state" msgs
| _ => throwUnsupportedSyntax
def truncateAnchors (es : Array (Expr × UInt64)) : Array (Expr × UInt64) × Nat :=
def truncateAnchors (es : Array (UInt64 × α)) : Array (UInt64 × α) × Nat :=
go 4
where
go (numDigits : Nat) : Array (Expr × UInt64) × Nat := Id.run do
go (numDigits : Nat) : Array (UInt64 × α) × Nat := Id.run do
if 4*numDigits < 64 then
let shift := 64 - 4*numDigits
let mut found : Std.HashSet UInt64 := {}
let mut result := #[]
for (e, a) in es do
for (a, e) in es do
let a' := a >>> shift.toUInt64
if found.contains a' then
return (← go (numDigits+1))
else
found := found.insert a'
result := result.push (e, a')
result := result.push (a', e)
return (result, numDigits)
else
return (es, numDigits)
@ -211,14 +211,42 @@ def anchorToString (numDigits : Nat) (anchor : UInt64) : String :=
filter.eval e
-- **TODO**: Add an option for including propositions that are only considered when using `+splitImp`
-- **TODO**: Add an option for including terms whose type is an inductive predicate or type
let candidates := candidates.map fun (e, _, anchor) => (e, anchor)
let candidates := candidates.map fun (e, _, anchor) => (anchor, e)
let (candidates, numDigits) := truncateAnchors candidates
if candidates.isEmpty then
throwError "no case splits"
let msgs := candidates.map fun (e, a) =>
let msgs := candidates.map fun (a, e) =>
.trace { cls := `split } m!"#{anchorToString numDigits a} := {e}" #[]
let msg := MessageData.trace { cls := `splits, collapsed := false } "Case split candidates" msgs
logInfo msg
| _ => throwUnsupportedSyntax
@[builtin_grind_tactic showThms] def evalShowThms : GrindTactic := fun _ => withMainContext do
let goal ← getMainGoal
let entries ← liftGrindM do
let (found, entries) ← go {} {} goal.ematch.thms
let (_, entries) ← go found entries goal.ematch.newThms
pure entries
let (entries, numDigits) := truncateAnchors entries
let msgs := entries.map fun (a, e) =>
.trace { cls := `thm } m!"#{anchorToString numDigits a} := {e}" #[]
let msg := MessageData.trace { cls := `thms, collapsed := false } "Local theorems" msgs
logInfo msg
where
go (found : Std.HashSet Grind.Origin) (result : Array (UInt64 × Expr)) (thms : PArray EMatchTheorem)
: GrindM (Std.HashSet Grind.Origin × Array (UInt64 × Expr)) := do
let mut found := found
let mut result := result
for thm in thms do
-- **Note**: We only display local theorems
if thm.origin matches .local _ | .fvar _ then
unless found.contains thm.origin do
found := found.insert thm.origin
let type ← inferType thm.proof
-- **Note**: Evaluate how stable these anchors are.
let anchor ← getAnchor type
result := result.push (anchor, type)
pure ()
return (found, result)
end Lean.Elab.Tactic.Grind

View file

@ -727,4 +727,16 @@ def ematch : GoalM Bool := do
ematchCore
return (← get).ematch.numInstances != numInstances
/-- Performs one round of E-matching using the giving theorems, and returns `true` if new instances were generated. -/
def ematchTheorems (thms : Array EMatchTheorem) : GoalM Bool := do
let numInstances := (← get).ematch.numInstances
go |>.run'
return (← get).ematch.numInstances != numInstances
where
go : EMatch.M Unit := do profileitM Exception "grind ematch" (← getOptions) do
withReader (fun ctx => { ctx with useMT := false }) do
if (← checkMaxInstancesExceeded <||> checkMaxEmatchExceeded) then
return ()
thms.forM ematchTheorem
end Lean.Meta.Grind

View file

@ -347,3 +347,60 @@ example {y z x : Int} : y = (z+1)*2 → x*y = 1 → x = 0 := by
grind -verbose =>
ring
sorry
example (as bs cs : Array α) (v₁ v₂ : α)
(i₁ i₂ j : Nat)
(h₁ : i₁ < as.size)
(h₂ : bs = as.set i₁ v₁)
(h₃ : i₂ < bs.size)
(h₃ : cs = bs.set i₂ v₂)
(h₄ : i₁ ≠ j ∧ i₂ ≠ j)
(h₅ : j < cs.size)
(h₆ : j < as.size)
: cs[j] = as[j] := by
grind =>
instantiate Array.getElem_set
instantiate Array.getElem_set
example (as bs cs : Array α) (v₁ v₂ : α)
(i₁ i₂ j : Nat)
(h₁ : i₁ < as.size)
(h₂ : bs = as.set i₁ v₁)
(h₃ : i₂ < bs.size)
(h₃ : cs = bs.set i₂ v₂)
(h₄ : i₁ ≠ j ∧ i₂ ≠ j)
(h₅ : j < cs.size)
(h₆ : j < as.size)
: cs[j] = as[j] := by
grind =>
instantiate = Array.getElem_set
instantiate ← Array.getElem_set
opaque p : Nat → Prop
opaque q : Nat → Prop
opaque f : Nat → Nat
opaque finv : Nat → Nat
axiom pq : p x → q x
axiom fInj : finv (f x) = x
example : f x = f y → p x → q y := by
grind =>
instantiate →pq, !fInj
/--
trace: [thms] Local theorems
[thm] #c5bb := ∀ (x : Nat), q x
[thm] #bfb8 := ∀ (x : Nat), p x → p (f x)
-/
#guard_msgs in
example : (∀ x, q x) → (∀ x, p x → p (f x)) → p x → p (f (f x)) := by
grind =>
show_thms
instantiate #bfb8
/-- error: no local theorems -/
#guard_msgs in
example : (∀ x, q x) → (∀ x, p x → p (f x)) → p x → p (f (f x)) := by
grind =>
instantiate #abcd