feat: new E-matching pattern inference for grind (#10342)

This PR implements a new E-matching pattern inference procedure that is
faithful to the behavior documented in the reference manual regarding
minimal indexable subexpressions. The old inference procedure was
failing to enforce this condition. For example, the manual documents
`[grind ->]` as follows

`[@grind →]` selects a multi-pattern from the hypotheses of the theorem.
In other words, `grind` will use the theorem for forwards reasoning.

To generate a pattern, it traverses the hypotheses of the theorem from
left to right. Each time it encounters a **minimal indexable
subexpression** which covers an argument which was not previously
covered, it adds that subexpression as a pattern, until all arguments
have been covered.

That said, the new procedure is currently disabled, and the following
option must be used to enable it.
```
set_option backward.grind.inferPattern false
```
Users can inspect differences between the old a new procedures using the
option
```
set_option backward.grind.checkInferPatternDiscrepancy true 
```
Example:
```lean
/--
warning: found discrepancy between old and new `grind` pattern inference procedures, old:
  [@List.length #2 (@toList _ #1 #0)]
new:
  [@toList #2 #1 #0]
use `set_option backward.grind.inferPattern true` to force old procedure
-/
#guard_msgs in
set_option backward.grind.checkInferPatternDiscrepancy true in
@[grind] theorem Vector.length_toList' (xs : Vector α n) : xs.toList.length = n := by sorry
```
This commit is contained in:
Leonardo de Moura 2025-09-10 22:27:11 -07:00 committed by GitHub
parent c3667e2861
commit 6b387da032
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 152 additions and 6 deletions

View file

@ -1009,6 +1009,7 @@ where
| .bvar idx => modify fun s => if s.contains idx then s else idx :: s
| _ => return ()
namespace OldCollector
private def diff (s : List Nat) (found : Std.HashSet Nat) : List Nat :=
if found.isEmpty then s else s.filter fun x => !found.contains x
@ -1069,19 +1070,129 @@ private partial def collect (e : Expr) : CollectorM Unit := do
collect b
| _ => return ()
end OldCollector
private def sizeOfDiff (s₁ s₂ : Std.HashSet Nat) : Nat :=
s₂.fold (init := s₁.size) fun num idx =>
if s₁.contains idx then num - 1 else num
/--
Normalizes `e` if it qualifies as a candidate pattern, and returns
`some p` where `p` is the normalized pattern.
`argKinds == NormalizePattern.getPatternArgKinds e.getAppFn e.getAppNumArgs`
-/
private def normalizePattern? (e : Expr) (argKinds : Array NormalizePattern.PatternArgKind) : CollectorM (Option Expr) := do
let p := e.abstract (← read).xs
unless p.hasLooseBVars do
trace[grind.debug.ematch.pattern] "skip, does not contain pattern variables"
return none
-- Normalization state before normalizing `e`
let stateBefore ← getThe NormalizePattern.State
let failed : CollectorM (Option Expr) := do
set stateBefore
return none
-- Returns the number of new variables with respect to `saved`
let getNumNewBVars : NormalizePattern.M Nat := do
return sizeOfDiff (← get).bvarsFound stateBefore.bvarsFound
try
let p ← NormalizePattern.normalizePattern p
let stateAfter ← getThe NormalizePattern.State
let numNewBVars ← getNumNewBVars
if numNewBVars == 0 then
trace[grind.debug.ematch.pattern] "skip, no new variables covered"
return (← failed)
/-
Checks whether one of `e`s children subsumes it. We say a child `c` subsumes `e`
1- `e` and `c` have the same new pattern variables. We say a pattern variable is new if it is not in `stateOld.bvarsFound`.
2- `c` is not a support argument. See `NormalizePattern.getPatternSupportMask` for definition.
3- `c` is not an offset pattern.
4- `c` is not a bound variable.
5- `c` is also a candidate.
-/
for arg in e.getAppArgs, argKind in argKinds do
unless argKind.isSupport do
unless arg.isFVar do
unless isOffsetPattern? arg |>.isSome do
if (← isPatternFnCandidate arg.getAppFn) then
let pArg := arg.abstract (← read).xs
set stateBefore
discard <| NormalizePattern.normalizePattern pArg
let numArgNewBVars ← getNumNewBVars
if numArgNewBVars == numNewBVars then
trace[grind.debug.ematch.pattern] "skip, subsumed by argument"
return (← failed)
set stateAfter
return some p
catch ex =>
trace[grind.debug.ematch.pattern] "skip, exception during normalization{indentD ex.toMessageData}"
failed
private partial def collect (e : Expr) : CollectorM Unit := do
if (← get).done then return ()
match e with
| .app .. =>
trace[grind.debug.ematch.pattern] "collect: {e}"
let f := e.getAppFn
let argKinds ← NormalizePattern.getPatternArgKinds f e.getAppNumArgs
if (← isPatternFnCandidate f) then
trace[grind.debug.ematch.pattern] "candidate: {e}"
if let some p ← normalizePattern? e argKinds then
addNewPattern p
return ()
let args := e.getAppArgs
for arg in args, argKind in argKinds do
unless isOffsetPattern? arg |>.isSome do
trace[grind.debug.ematch.pattern] "arg: {arg}, support: {argKind.isSupport}"
unless argKind.isSupport do
collect arg
| .forallE _ d b _ =>
if (← pure e.isArrow <&&> isProp d <&&> isProp b) then
collect d
collect b
| _ => return ()
register_builtin_option backward.grind.inferPattern : Bool := {
defValue := true
group := "backward compatibility"
descr := "use old E-matching pattern inference"
}
register_builtin_option backward.grind.checkInferPatternDiscrepancy : Bool := {
defValue := false
group := "backward compatibility"
descr := "check whether old and new pattern inference procedures infer the same pattern"
}
private def collectPatterns? (proof : Expr) (xs : Array Expr) (searchPlaces : Array Expr) (symPrios : SymbolPriorities) (minPrio : Nat)
: MetaM (Option (List Expr × List HeadIndex)) := do
let go : CollectorM (Option (List Expr)) := do
let go (useOld : Bool): CollectorM (Option (List Expr)) := do
for place in searchPlaces do
trace[grind.debug.ematch.pattern] "place: {place}"
let place ← preprocessPattern place
collect place
if useOld then
OldCollector.collect place
else
collect place
if (← get).done then
return some ((← get).patterns.toList)
return none
let (some ps, s) ← go { proof, xs } |>.run' {} { symPrios, minPrio } |>.run {}
| return none
return some (ps, s.symbols.toList)
let collect? (useOld : Bool) : MetaM (Option (List Expr × List HeadIndex)) := do
let (some ps, s) ← go useOld { proof, xs } |>.run' {} { symPrios, minPrio } |>.run {}
| return none
return some (ps, s.symbols.toList)
let useOld := backward.grind.inferPattern.get (← getOptions)
if backward.grind.checkInferPatternDiscrepancy.get (← getOptions) then
let oldResult? ← collect? (useOld := true)
let newResult? ← collect? (useOld := false)
let toPattern (result? : Option (List Expr × List HeadIndex)) : List MessageData :=
let pat := result?.map (·.1) |>.getD []
pat.map ppPattern
if oldResult? != newResult? then
logWarning m!"found discrepancy between old and new `grind` pattern inference procedures, old:{indentD (toPattern oldResult?)}\nnew:{indentD (toPattern newResult?)}"
return if useOld then oldResult? else newResult?
else
collect? useOld
/--
Tries to find a ground pattern to activate the theorem.

View file

@ -1,5 +1,5 @@
// update me!
#include "util/options.h"
// Please update stage0
namespace lean {
options get_default_options() {
options opts;

View file

@ -0,0 +1,35 @@
set_option warn.sorry false
namespace Test
set_option backward.grind.inferPattern false -- Force new pattern inference procedure
inductive Vector (α : Type) : Nat → Type where
| nil : Vector α 0
| cons (x : α) (xs : Vector α n) : Vector α (n + 1)
def Vector.ofList (xs : List α) (h : xs.length = n) : Vector α n :=
match n, xs with
| 0, [] => .nil
| (n + 1), x :: xs => .cons x (.ofList xs (by grind))
def Vector.toList (xs : Vector α n) : List α :=
match xs with
| .nil => []
| .cons x xs => x :: xs.toList
/-- info: length_toList: [@toList #2 #1 #0] -/
#guard_msgs (info) in
@[grind?] theorem Vector.length_toList (xs : Vector α n) : xs.toList.length = n := by sorry
def wrapper (f : Nat → Nat → List α → List α) (h : ∀ n m xs, xs.length = n → (f n m xs).length = m) :
(n m : Nat) → Vector α n → Vector α m :=
fun n m xs => Vector.ofList (f n m xs.toList) (by grind) -- apply h; apply Vector.length_toList) -- fails here: (by grind)
/--
warning: found discrepancy between old and new `grind` pattern inference procedures, old:
[@List.length #2 (@toList _ #1 #0)]
new:
[@toList #2 #1 #0]
-/
#guard_msgs in
set_option backward.grind.checkInferPatternDiscrepancy true in
@[grind] theorem Vector.length_toList' (xs : Vector α n) : xs.toList.length = n := by sorry