feat: guard and check in grind_pattern (#11428)

This PR implements support for **guards** in `grind_pattern`. The new
feature provides additional control over theorem instantiation. For
example, consider the following monotonicity theorem:

```lean
opaque f : Nat → Nat
theorem fMono : x ≤ y → f x ≤ f y := ...
```

We can use `grind_pattern` to instruct `grind` to instantiate the
theorem for every pair `f x` and `f y` occurring in the goal:

```lean
grind_pattern fMono => f x, f y
```

Then we can automatically prove the following simple example using
`grind`:

```lean
/--
trace: [grind.ematch.instance] fMono: f a ≤ b → f (f a) ≤ f b
[grind.ematch.instance] fMono: f a ≤ c → f (f a) ≤ f c
[grind.ematch.instance] fMono: f a ≤ a → f (f a) ≤ f a
[grind.ematch.instance] fMono: f a ≤ f (f a) → f (f a) ≤ f (f (f a))
[grind.ematch.instance] fMono: f a ≤ f a → f (f a) ≤ f (f a)
[grind.ematch.instance] fMono: f (f a) ≤ b → f (f (f a)) ≤ f b
[grind.ematch.instance] fMono: f (f a) ≤ c → f (f (f a)) ≤ f c
[grind.ematch.instance] fMono: f (f a) ≤ a → f (f (f a)) ≤ f a
[grind.ematch.instance] fMono: f (f a) ≤ f (f a) → f (f (f a)) ≤ f (f (f a))
[grind.ematch.instance] fMono: f (f a) ≤ f a → f (f (f a)) ≤ f (f a)
[grind.ematch.instance] fMono: a ≤ b → f a ≤ f b
[grind.ematch.instance] fMono: a ≤ c → f a ≤ f c
[grind.ematch.instance] fMono: a ≤ a → f a ≤ f a
[grind.ematch.instance] fMono: a ≤ f (f a) → f a ≤ f (f (f a))
[grind.ematch.instance] fMono: a ≤ f a → f a ≤ f (f a)
[grind.ematch.instance] fMono: c ≤ b → f c ≤ f b
[grind.ematch.instance] fMono: c ≤ c → f c ≤ f c
[grind.ematch.instance] fMono: c ≤ a → f c ≤ f a
[grind.ematch.instance] fMono: c ≤ f (f a) → f c ≤ f (f (f a))
[grind.ematch.instance] fMono: c ≤ f a → f c ≤ f (f a)
[grind.ematch.instance] fMono: b ≤ b → f b ≤ f b
[grind.ematch.instance] fMono: b ≤ c → f b ≤ f c
[grind.ematch.instance] fMono: b ≤ a → f b ≤ f a
[grind.ematch.instance] fMono: b ≤ f (f a) → f b ≤ f (f (f a))
[grind.ematch.instance] fMono: b ≤ f a → f b ≤ f (f a)
-/
#guard_msgs in
example : f b = f c → a ≤ f a → f (f a) ≤ f (f (f a)) := by
  set_option trace.grind.ematch.instance true in
  grind
```

However, many unnecessary theorem instantiations are generated.

With the new `guard` feature, we can instruct `grind` to instantiate the
theorem **only if** `x ≤ y` is already known to be true in the current
`grind` state:

```lean
grind_pattern fMono => f x, f y where
  guard x ≤ y
  x =/= y
```

If we run the example again, only three instances are generated:

```lean
/--
trace: [grind.ematch.instance] fMono: a ≤ f a → f a ≤ f (f a)
[grind.ematch.instance] fMono: f a ≤ f (f a) → f (f a) ≤ f (f (f a))
[grind.ematch.instance] fMono: a ≤ f (f a) → f a ≤ f (f (f a))
-/
#guard_msgs in
example : f b = f c → a ≤ f a → f (f a) ≤ f (f (f a)) := by
  set_option trace.grind.ematch.instance true in
  grind
```

Note that `guard` does **not** check whether the expression is
*implied*. It only checks whether the expression is *already known* to
be true in the current `grind` state. If this fact is eventually
learned, the theorem will be instantiated.

If you want `grind` to check whether the expression is implied, you
should use:

```lean
grind_pattern fMono => f x, f y where
  check x ≤ y
  x =/= y
```

Remark: we can use multiple `guard`/`check`s in a `grind_pattern`
command.
This commit is contained in:
Leonardo de Moura 2025-11-28 19:56:53 -08:00 committed by GitHub
parent 3f05179fdb
commit 075f1d66eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 279 additions and 21 deletions

View file

@ -61,6 +61,7 @@ builtin_initialize registerTraceClass `grind.ematch
builtin_initialize registerTraceClass `grind.ematch.pattern
builtin_initialize registerTraceClass `grind.ematch.instance
builtin_initialize registerTraceClass `grind.ematch.instance.assignment
builtin_initialize registerTraceClass `grind.ematch.instance.delayed
builtin_initialize registerTraceClass `grind.eqResolution
builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.simp

View file

@ -240,6 +240,16 @@ where
propagateDown e
propagateUnitConstFuns lams₁ lams₂
toPropagateSolvers.propagate
if rhsNode.root.isTrue then
checkDelayedThmInsts toPropagateDown
checkDelayedThmInsts (toPropagateDown : List Expr) : GoalM Unit := do
if (← isInconsistent) then return ()
if (← get).delayedThmInsts.isEmpty then return ()
for e in toPropagateDown do
let some delayedThms := (← get).delayedThmInsts.find? { expr := e } | pure ()
modify fun s => { s with delayedThmInsts := s.delayedThmInsts.erase { expr := e } }
delayedThms.forM (·.check)
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
let isFalseRoot ← isFalseExpr rootNew
traverseEqc lhs fun n => do

View file

@ -461,7 +461,7 @@ macro "reportEMatchIssue!" s:(interpolatedStr(term) <|> term) : doElem => do
Stores new theorem instance in the state.
Recall that new instances are internalized later, after a full round of ematching.
-/
private def addNewInstance (thm : EMatchTheorem) (proof : Expr) (generation : Nat) : M Unit := do
private def addNewInstance (thm : EMatchTheorem) (proof : Expr) (generation : Nat) (guards : List TheoremGuard) : M Unit := do
let proof ← instantiateMVars proof
if grind.debug.proofs.get (← getOptions) then
check proof
@ -499,8 +499,7 @@ where
-- We must add a hint because `annotateEqnTypeConds` introduces `Grind.PreMatchCond`
-- which is not reducible.
proof := mkExpectedPropHint proof prop
trace_goal[grind.ematch.instance] "{thm.origin.pp}: {prop}"
addTheoremInstance thm proof prop (generation+1)
addTheoremInstance thm proof prop (generation+1) guards
private def synthesizeInsts (mvars : Array Expr) (bis : Array BinderInfo) : OptionT M Unit := do
let thm := (← read).thm
@ -741,7 +740,33 @@ private def checkConstraints (thm : EMatchTheorem) (gen : Nat) (proof : Expr) (a
It may be useful to bound the number of instances in the current branch.
-/
return (← getEMatchTheoremNumInstances thm) + 1 < n
| _ => throwError "NIY"
| .check _ | .guard _ => return true
private def collectGuards (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM (List TheoremGuard) := do
if thm.cnstrs.isEmpty then return []
/- **Note**: Only top-level theorems have constraints. -/
let .const declName us := proof | return []
unless thm.cnstrs.any fun c => c matches .check _ | .guard _ do return []
let info ← getConstInfo declName
let mut result := #[]
let applySubst (e : Expr) : GoalM (Option Expr) := do
let e := e.instantiateRev args
let e := e.instantiateLevelParams info.levelParams us
let e ← instantiateMVars e
if e.hasMVar then
reportIssue! "guard for `{thm.origin.pp}` was skipped because it contains metavariables after theorem instantiation{indentExpr e}"
return none
return some e
for cnstr in thm.cnstrs do
match cnstr with
| .check e =>
let some e ← applySubst e | pure ()
result := result.push <| { e, check := true }
| .guard e =>
let some e ← applySubst e | pure ()
result := result.push <| { e, check := false }
| _ => pure ()
return result.toList
/--
After processing a (multi-)pattern, use the choice assignment to instantiate the proof.
@ -762,15 +787,16 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
let (some _, c) ← applyAssignment mvars |>.run c | return ()
let some _ ← synthesizeInsts mvars bis | return ()
if (← checkConstraints thm c.gen proof mvars) then
let guards ← collectGuards thm proof mvars
let proof := mkAppN proof mvars
if (← mvars.allM (·.mvarId!.isAssigned)) then
addNewInstance thm proof c.gen
addNewInstance thm proof c.gen guards
else
let mvars ← mvars.filterM fun mvar => return !(← mvar.mvarId!.isAssigned)
if let some mvarBad ← mvars.findM? fun mvar => return !(← isProof mvar) then
reportEMatchIssue! "failed to instantiate {thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}"
let proof ← mkLambdaFVars (binderInfoForMVars := .default) mvars (← instantiateMVars proof)
addNewInstance thm proof c.gen
addNewInstance thm proof c.gen guards
/-- Process choice stack until we don't have more choices to be processed. -/
private def processChoices : M Unit := do
@ -891,8 +917,19 @@ Recall that the mapping is nonempty only if tracing is enabled.
-/
def ematch' (extraThms : Array EMatchTheorem := #[]) : GoalM (Bool × InstanceMap) := do
let numInstances := (← get).ematch.numInstances
let numDelayedInstances := (← get).ematch.numDelayedInstances
let map ← ematchCore extraThms
return ((← get).ematch.numInstances != numInstances, map)
let progress :=
(← get).ematch.numInstances != numInstances
||
(← get).ematch.numDelayedInstances != numDelayedInstances
if (← get).ematch.numDelayedInstances != numDelayedInstances then
/-
**Note**: If delayed instances were produced, new guards may have been internalized,
and we may have pending facts to process.
-/
processNewFacts
return (progress, map)
/--
Performs one round of E-matching, and returns `true` if new instances were generated.

View file

@ -28,7 +28,7 @@ end GrindCnstr
open GrindCnstr in
def grindPatternCnstr : Parser :=
isValue <|> isStrictValue <|> isGround <|> sizeLt <|> depthLt <|> genLt <|> maxInsts
<|> guard <|> check <|> notDefEq <|> defEq
<|> guard <|> GrindCnstr.check <|> notDefEq <|> defEq
def grindPatternCnstrs : Parser := leading_parser "where " >> many1Indent (ppLine >> grindPatternCnstr)

View file

@ -45,7 +45,8 @@ def dsimpCore (e : Expr) : GrindM Expr := do profileitM Exception "grind dsimp"
Preprocesses `e` using `grind` normalization theorems and simprocs,
and then applies several other preprocessing steps.
-/
def preprocess (e : Expr) : GoalM Simp.Result := do
@[export lean_grind_preprocess]
def preprocessImpl (e : Expr) : GoalM Simp.Result := do
let e ← instantiateMVars e
let r ← simpCore e
/-

View file

@ -121,17 +121,20 @@ inductive SplitSource where
input
| /-- Injectivity theorem. -/
inj (origin : Origin)
| /-- `grind_pattern` guard -/
guard (origin : Origin)
deriving Inhabited
def SplitSource.toMessageData : SplitSource → MessageData
| .ematch origin => m!"E-matching {origin.pp}"
| .ext declName => m!"Extensionality {declName}"
| .ematch origin => m!"E-matching `{origin.pp}`"
| .guard origin => m!"Theorem instantiation guard for `{origin.pp}`"
| .ext declName => m!"Extensionality `{declName}`"
| .mbtc a b i => m!"Model-based theory combination at argument #{i} of{indentExpr a}\nand{indentExpr b}"
| .beta e => m!"Beta-reduction of{indentExpr e}"
| .forallProp e => m!"Forall propagation at{indentExpr e}"
| .existsProp e => m!"Exists propagation at{indentExpr e}"
| .input => "Initial goal"
| .inj origin => m!"Injectivity {origin.pp}"
| .inj origin => m!"Injectivity `{origin.pp}`"
/-- Context for `GrindM` monad. -/
structure Context where
@ -762,8 +765,10 @@ structure EMatch.State where
thms : PArray EMatchTheorem := {}
/-- Active theorems that we have not performed any round of ematching yet. -/
newThms : PArray EMatchTheorem := {}
/-- Number of theorem instances generated so far -/
/-- Number of theorem instances generated so far. -/
numInstances : Nat := 0
/-- Number of delayed theorem instances generated so far. We track them to decide whether E-match made progress or not. -/
numDelayedInstances : Nat := 0
/-- Number of E-matching rounds performed in this goal since the last case-split. -/
num : Nat := 0
/-- (pre-)instances found so far. It includes instances that failed to be instantiated. -/
@ -900,6 +905,30 @@ structure Injective.State where
fns : PHashMap ExprPtr InjectiveInfo := {}
deriving Inhabited
/--
Users can attach guards to `grind_pattern`s. A guard ensures that a theorem is instantiated
only when the guard expression becomes provably true.
If `check` is `true`, then `grind` attempts to prove `e` by asserting its negation and
checking whether this leads to a contradiction.
-/
structure TheoremGuard where
e : Expr
check : Bool
deriving Inhabited
/--
A delayed theorem instantiation is an instantiation that includes one or more guards.
See `TheoremGuard`.
-/
structure DelayedTheoremInstance where
thm : EMatchTheorem
proof : Expr
prop : Expr
generation : Nat
guards : List TheoremGuard
deriving Inhabited
/-- The `grind` goal. -/
structure Goal where
mvarId : MVarId
@ -936,6 +965,11 @@ structure Goal where
clean : Clean.State := {}
/-- Solver states. -/
sstates : Array SolverExtensionState := #[]
/--
Delayed instantiations is a mapping from guards to theorems that are waiting them
to become `True`.
-/
delayedThmInsts : PHashMap ExprPtr (List DelayedTheoremInstance) := {}
deriving Inhabited
def Goal.hasSameRoot (g : Goal) (a b : Expr) : Bool :=
@ -1001,12 +1035,6 @@ def addNewRawFact (proof : Expr) (prop : Expr) (generation : Nat) (splitSource :
def getNumTheoremInstances : GoalM Nat := do
return (← get).ematch.numInstances
/-- Adds a new theorem instance produced using E-matching. -/
def addTheoremInstance (thm : EMatchTheorem) (proof : Expr) (prop : Expr) (generation : Nat) : GoalM Unit := do
saveEMatchTheorem thm
addNewRawFact proof prop generation (.ematch thm.origin)
modify fun s => { s with ematch.numInstances := s.ematch.numInstances + 1 }
/-- Returns `true` if the maximum number of instances has been reached. -/
def checkMaxInstancesExceeded : GoalM Bool := do
return (← get).ematch.numInstances >= (← getConfig).instances
@ -1316,13 +1344,17 @@ It assumes `a` and `b` are in the same equivalence class.
@[extern "lean_grind_mk_heq_proof"]
opaque mkHEqProof (a b : Expr) : GoalM Expr
-- Forward definition
@[extern "lean_grind_process_new_facts"]
opaque processNewFacts : GoalM Unit
-- Forward definition
@[extern "lean_grind_internalize"]
opaque internalize (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit
-- Forward definition
@[extern "lean_grind_process_new_facts"]
opaque processNewFacts : GoalM Unit
@[extern "lean_grind_preprocess"]
opaque preprocess : Expr → GoalM Simp.Result
/--
Internalizes a local declaration which is not a proposition.
@ -1589,6 +1621,45 @@ def addSplitCandidate (sinfo : SplitInfo) : GoalM Unit := do
}
updateSplitArgPosMap sinfo
inductive ActivateNextGuardResult where
| ready
| next (guard : Expr) (pending : List TheoremGuard)
def activateNextGuard (thm : EMatchTheorem) (guards : List TheoremGuard) (generation : Nat) : GoalM ActivateNextGuardResult := do
go guards
where
go : List TheoremGuard → GoalM ActivateNextGuardResult
| [] => return .ready
| guard :: guards => do
let { expr := e, .. } ← preprocess guard.e
internalize e generation
if (← isEqTrue e) then
go guards
else
if guard.check then
addSplitCandidate <| .default e (.guard thm.origin)
return .next e guards
/-- Adds a new theorem instance produced using E-matching. -/
def addTheoremInstance (thm : EMatchTheorem) (proof : Expr) (prop : Expr) (generation : Nat) (guards : List TheoremGuard) : GoalM Unit := do
match (← activateNextGuard thm guards generation) with
| .ready =>
trace_goal[grind.ematch.instance] "{thm.origin.pp}: {prop}"
saveEMatchTheorem thm
addNewRawFact proof prop generation (.ematch thm.origin)
modify fun s => { s with ematch.numInstances := s.ematch.numInstances + 1 }
| .next guard guards =>
let thms := (← get).delayedThmInsts.find? { expr := guard } |>.getD []
let thms := { thm, proof, prop, generation, guards } :: thms
trace_goal[grind.ematch.instance.delayed] "`{thm.origin.pp}` waiting{indentExpr guard}"
modify fun s => { s with
delayedThmInsts := s.delayedThmInsts.insert { expr := guard } thms
ematch.numDelayedInstances := s.ematch.numDelayedInstances + 1
}
def DelayedTheoremInstance.check (delayed : DelayedTheoremInstance) : GoalM Unit := do
addTheoremInstance delayed.thm delayed.proof delayed.prop delayed.generation delayed.guards
/--
Returns extensionality theorems for the given type if available.
If `Config.ext` is `false`, the result is `#[]`.

View file

@ -155,3 +155,141 @@ example
grind
end Ex8
namespace Ex9
opaque f : Nat → Nat → Nat
axiom fax : x ≠ y → f x y > 0
grind_pattern fax => f x y where
guard x ≠ y
/--
trace: [grind.ematch.instance.delayed] `fax` waiting
¬x = y
-/
#guard_msgs (trace, drop error) in
example : f x y = 5 → False := by
set_option trace.grind.ematch.instance true in
set_option trace.grind.ematch.instance.delayed true in
grind
/--
trace: [grind.ematch.instance.delayed] `fax` waiting
¬x = y
[grind.ematch.instance] fax: x ≠ y → f x y > 0
-/
#guard_msgs in
example : x ≠ y → f x y = 0 → False := by
set_option trace.grind.ematch.instance true in
set_option trace.grind.ematch.instance.delayed true in
grind
end Ex9
namespace Ex10
opaque f : Nat → Nat → Nat
axiom fax : x = y → f x y > 0
grind_pattern fax => f x y where
check x = y
/--
trace: [grind.ematch.instance.delayed] `fax` waiting
x = y
[grind.split] x = y, generation: 1
[grind.ematch.instance] fax: x = y → f x y > 0
-/
#guard_msgs (drop error, trace) in
example : f x y = 0 → False := by
set_option trace.grind.ematch.instance true in
set_option trace.grind.ematch.instance.delayed true in
set_option trace.grind.split true in
grind
end Ex10
namespace Ex11
opaque f : Nat → Nat → Nat
axiom fax : x = y → f x y > 0
grind_pattern fax => f x y where
guard x = y
-- `grind` will not case-split on `x = y` since `guard` was used instead of `check`
/--
trace: [grind.ematch.instance.delayed] `fax` waiting
x = y
-/
#guard_msgs (drop error, trace) in
example : f x y = 0 → False := by
set_option trace.grind.ematch.instance true in
set_option trace.grind.ematch.instance.delayed true in
set_option trace.grind.split true in
grind
end Ex11
namespace Ex12
opaque f : Nat → Nat
axiom fMono : x ≤ y → f x ≤ f y
grind_pattern fMono => f x, f y
-- Many unnecessary instances were generated.
/--
trace: [grind.ematch.instance] fMono: f a ≤ b → f (f a) ≤ f b
[grind.ematch.instance] fMono: f a ≤ c → f (f a) ≤ f c
[grind.ematch.instance] fMono: f a ≤ a → f (f a) ≤ f a
[grind.ematch.instance] fMono: f a ≤ f (f a) → f (f a) ≤ f (f (f a))
[grind.ematch.instance] fMono: f a ≤ f a → f (f a) ≤ f (f a)
[grind.ematch.instance] fMono: f (f a) ≤ b → f (f (f a)) ≤ f b
[grind.ematch.instance] fMono: f (f a) ≤ c → f (f (f a)) ≤ f c
[grind.ematch.instance] fMono: f (f a) ≤ a → f (f (f a)) ≤ f a
[grind.ematch.instance] fMono: f (f a) ≤ f (f a) → f (f (f a)) ≤ f (f (f a))
[grind.ematch.instance] fMono: f (f a) ≤ f a → f (f (f a)) ≤ f (f a)
[grind.ematch.instance] fMono: a ≤ b → f a ≤ f b
[grind.ematch.instance] fMono: a ≤ c → f a ≤ f c
[grind.ematch.instance] fMono: a ≤ a → f a ≤ f a
[grind.ematch.instance] fMono: a ≤ f (f a) → f a ≤ f (f (f a))
[grind.ematch.instance] fMono: a ≤ f a → f a ≤ f (f a)
[grind.ematch.instance] fMono: c ≤ b → f c ≤ f b
[grind.ematch.instance] fMono: c ≤ c → f c ≤ f c
[grind.ematch.instance] fMono: c ≤ a → f c ≤ f a
[grind.ematch.instance] fMono: c ≤ f (f a) → f c ≤ f (f (f a))
[grind.ematch.instance] fMono: c ≤ f a → f c ≤ f (f a)
[grind.ematch.instance] fMono: b ≤ b → f b ≤ f b
[grind.ematch.instance] fMono: b ≤ c → f b ≤ f c
[grind.ematch.instance] fMono: b ≤ a → f b ≤ f a
[grind.ematch.instance] fMono: b ≤ f (f a) → f b ≤ f (f (f a))
[grind.ematch.instance] fMono: b ≤ f a → f b ≤ f (f a)
-/
#guard_msgs in
example : f b = f c → a ≤ f a → f (f a) ≤ f (f (f a)) := by
set_option trace.grind.ematch.instance true in
grind
end Ex12
namespace Ex13
-- Same example but using constraints to control theorem/axiom instantiation
opaque f : Nat → Nat
axiom fMono : x ≤ y → f x ≤ f y
grind_pattern fMono => f x, f y where
guard x ≤ y
x =/= y
/--
trace: [grind.ematch.instance] fMono: a ≤ f a → f a ≤ f (f a)
[grind.ematch.instance] fMono: f a ≤ f (f a) → f (f a) ≤ f (f (f a))
[grind.ematch.instance] fMono: a ≤ f (f a) → f a ≤ f (f (f a))
-/
#guard_msgs in
example : f b = f c → a ≤ f a → f (f a) ≤ f (f (f a)) := by
set_option trace.grind.ematch.instance true in
grind
end Ex13