feat: new E-matching pattern inference for grind (#10342)
This PR implements a new E-matching pattern inference procedure that is faithful to the behavior documented in the reference manual regarding minimal indexable subexpressions. The old inference procedure was failing to enforce this condition. For example, the manual documents `[grind ->]` as follows `[@grind →]` selects a multi-pattern from the hypotheses of the theorem. In other words, `grind` will use the theorem for forwards reasoning. To generate a pattern, it traverses the hypotheses of the theorem from left to right. Each time it encounters a **minimal indexable subexpression** which covers an argument which was not previously covered, it adds that subexpression as a pattern, until all arguments have been covered. That said, the new procedure is currently disabled, and the following option must be used to enable it. ``` set_option backward.grind.inferPattern false ``` Users can inspect differences between the old a new procedures using the option ``` set_option backward.grind.checkInferPatternDiscrepancy true ``` Example: ```lean /-- warning: found discrepancy between old and new `grind` pattern inference procedures, old: [@List.length #2 (@toList _ #1 #0)] new: [@toList #2 #1 #0] use `set_option backward.grind.inferPattern true` to force old procedure -/ #guard_msgs in set_option backward.grind.checkInferPatternDiscrepancy true in @[grind] theorem Vector.length_toList' (xs : Vector α n) : xs.toList.length = n := by sorry ```
This commit is contained in:
parent
c3667e2861
commit
6b387da032
3 changed files with 152 additions and 6 deletions
|
|
@ -1009,6 +1009,7 @@ where
|
|||
| .bvar idx => modify fun s => if s.contains idx then s else idx :: s
|
||||
| _ => return ()
|
||||
|
||||
namespace OldCollector
|
||||
private def diff (s : List Nat) (found : Std.HashSet Nat) : List Nat :=
|
||||
if found.isEmpty then s else s.filter fun x => !found.contains x
|
||||
|
||||
|
|
@ -1069,19 +1070,129 @@ private partial def collect (e : Expr) : CollectorM Unit := do
|
|||
collect b
|
||||
| _ => return ()
|
||||
|
||||
end OldCollector
|
||||
|
||||
private def sizeOfDiff (s₁ s₂ : Std.HashSet Nat) : Nat :=
|
||||
s₂.fold (init := s₁.size) fun num idx =>
|
||||
if s₁.contains idx then num - 1 else num
|
||||
|
||||
/--
|
||||
Normalizes `e` if it qualifies as a candidate pattern, and returns
|
||||
`some p` where `p` is the normalized pattern.
|
||||
|
||||
`argKinds == NormalizePattern.getPatternArgKinds e.getAppFn e.getAppNumArgs`
|
||||
-/
|
||||
private def normalizePattern? (e : Expr) (argKinds : Array NormalizePattern.PatternArgKind) : CollectorM (Option Expr) := do
|
||||
let p := e.abstract (← read).xs
|
||||
unless p.hasLooseBVars do
|
||||
trace[grind.debug.ematch.pattern] "skip, does not contain pattern variables"
|
||||
return none
|
||||
-- Normalization state before normalizing `e`
|
||||
let stateBefore ← getThe NormalizePattern.State
|
||||
let failed : CollectorM (Option Expr) := do
|
||||
set stateBefore
|
||||
return none
|
||||
-- Returns the number of new variables with respect to `saved`
|
||||
let getNumNewBVars : NormalizePattern.M Nat := do
|
||||
return sizeOfDiff (← get).bvarsFound stateBefore.bvarsFound
|
||||
try
|
||||
let p ← NormalizePattern.normalizePattern p
|
||||
let stateAfter ← getThe NormalizePattern.State
|
||||
let numNewBVars ← getNumNewBVars
|
||||
if numNewBVars == 0 then
|
||||
trace[grind.debug.ematch.pattern] "skip, no new variables covered"
|
||||
return (← failed)
|
||||
/-
|
||||
Checks whether one of `e`s children subsumes it. We say a child `c` subsumes `e`
|
||||
1- `e` and `c` have the same new pattern variables. We say a pattern variable is new if it is not in `stateOld.bvarsFound`.
|
||||
2- `c` is not a support argument. See `NormalizePattern.getPatternSupportMask` for definition.
|
||||
3- `c` is not an offset pattern.
|
||||
4- `c` is not a bound variable.
|
||||
5- `c` is also a candidate.
|
||||
-/
|
||||
for arg in e.getAppArgs, argKind in argKinds do
|
||||
unless argKind.isSupport do
|
||||
unless arg.isFVar do
|
||||
unless isOffsetPattern? arg |>.isSome do
|
||||
if (← isPatternFnCandidate arg.getAppFn) then
|
||||
let pArg := arg.abstract (← read).xs
|
||||
set stateBefore
|
||||
discard <| NormalizePattern.normalizePattern pArg
|
||||
let numArgNewBVars ← getNumNewBVars
|
||||
if numArgNewBVars == numNewBVars then
|
||||
trace[grind.debug.ematch.pattern] "skip, subsumed by argument"
|
||||
return (← failed)
|
||||
set stateAfter
|
||||
return some p
|
||||
catch ex =>
|
||||
trace[grind.debug.ematch.pattern] "skip, exception during normalization{indentD ex.toMessageData}"
|
||||
failed
|
||||
|
||||
private partial def collect (e : Expr) : CollectorM Unit := do
|
||||
if (← get).done then return ()
|
||||
match e with
|
||||
| .app .. =>
|
||||
trace[grind.debug.ematch.pattern] "collect: {e}"
|
||||
let f := e.getAppFn
|
||||
let argKinds ← NormalizePattern.getPatternArgKinds f e.getAppNumArgs
|
||||
if (← isPatternFnCandidate f) then
|
||||
trace[grind.debug.ematch.pattern] "candidate: {e}"
|
||||
if let some p ← normalizePattern? e argKinds then
|
||||
addNewPattern p
|
||||
return ()
|
||||
let args := e.getAppArgs
|
||||
for arg in args, argKind in argKinds do
|
||||
unless isOffsetPattern? arg |>.isSome do
|
||||
trace[grind.debug.ematch.pattern] "arg: {arg}, support: {argKind.isSupport}"
|
||||
unless argKind.isSupport do
|
||||
collect arg
|
||||
| .forallE _ d b _ =>
|
||||
if (← pure e.isArrow <&&> isProp d <&&> isProp b) then
|
||||
collect d
|
||||
collect b
|
||||
| _ => return ()
|
||||
|
||||
register_builtin_option backward.grind.inferPattern : Bool := {
|
||||
defValue := true
|
||||
group := "backward compatibility"
|
||||
descr := "use old E-matching pattern inference"
|
||||
}
|
||||
|
||||
register_builtin_option backward.grind.checkInferPatternDiscrepancy : Bool := {
|
||||
defValue := false
|
||||
group := "backward compatibility"
|
||||
descr := "check whether old and new pattern inference procedures infer the same pattern"
|
||||
}
|
||||
|
||||
private def collectPatterns? (proof : Expr) (xs : Array Expr) (searchPlaces : Array Expr) (symPrios : SymbolPriorities) (minPrio : Nat)
|
||||
: MetaM (Option (List Expr × List HeadIndex)) := do
|
||||
let go : CollectorM (Option (List Expr)) := do
|
||||
let go (useOld : Bool): CollectorM (Option (List Expr)) := do
|
||||
for place in searchPlaces do
|
||||
trace[grind.debug.ematch.pattern] "place: {place}"
|
||||
let place ← preprocessPattern place
|
||||
collect place
|
||||
if useOld then
|
||||
OldCollector.collect place
|
||||
else
|
||||
collect place
|
||||
if (← get).done then
|
||||
return some ((← get).patterns.toList)
|
||||
return none
|
||||
let (some ps, s) ← go { proof, xs } |>.run' {} { symPrios, minPrio } |>.run {}
|
||||
| return none
|
||||
return some (ps, s.symbols.toList)
|
||||
let collect? (useOld : Bool) : MetaM (Option (List Expr × List HeadIndex)) := do
|
||||
let (some ps, s) ← go useOld { proof, xs } |>.run' {} { symPrios, minPrio } |>.run {}
|
||||
| return none
|
||||
return some (ps, s.symbols.toList)
|
||||
let useOld := backward.grind.inferPattern.get (← getOptions)
|
||||
if backward.grind.checkInferPatternDiscrepancy.get (← getOptions) then
|
||||
let oldResult? ← collect? (useOld := true)
|
||||
let newResult? ← collect? (useOld := false)
|
||||
let toPattern (result? : Option (List Expr × List HeadIndex)) : List MessageData :=
|
||||
let pat := result?.map (·.1) |>.getD []
|
||||
pat.map ppPattern
|
||||
if oldResult? != newResult? then
|
||||
logWarning m!"found discrepancy between old and new `grind` pattern inference procedures, old:{indentD (toPattern oldResult?)}\nnew:{indentD (toPattern newResult?)}"
|
||||
return if useOld then oldResult? else newResult?
|
||||
else
|
||||
collect? useOld
|
||||
|
||||
/--
|
||||
Tries to find a ground pattern to activate the theorem.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
// update me!
|
||||
#include "util/options.h"
|
||||
// Please update stage0
|
||||
namespace lean {
|
||||
options get_default_options() {
|
||||
options opts;
|
||||
|
|
|
|||
35
tests/lean/run/grind_pattern_inference_issue.lean
Normal file
35
tests/lean/run/grind_pattern_inference_issue.lean
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
set_option warn.sorry false
|
||||
namespace Test
|
||||
set_option backward.grind.inferPattern false -- Force new pattern inference procedure
|
||||
|
||||
inductive Vector (α : Type) : Nat → Type where
|
||||
| nil : Vector α 0
|
||||
| cons (x : α) (xs : Vector α n) : Vector α (n + 1)
|
||||
|
||||
def Vector.ofList (xs : List α) (h : xs.length = n) : Vector α n :=
|
||||
match n, xs with
|
||||
| 0, [] => .nil
|
||||
| (n + 1), x :: xs => .cons x (.ofList xs (by grind))
|
||||
|
||||
def Vector.toList (xs : Vector α n) : List α :=
|
||||
match xs with
|
||||
| .nil => []
|
||||
| .cons x xs => x :: xs.toList
|
||||
|
||||
/-- info: length_toList: [@toList #2 #1 #0] -/
|
||||
#guard_msgs (info) in
|
||||
@[grind?] theorem Vector.length_toList (xs : Vector α n) : xs.toList.length = n := by sorry
|
||||
|
||||
def wrapper (f : Nat → Nat → List α → List α) (h : ∀ n m xs, xs.length = n → (f n m xs).length = m) :
|
||||
(n m : Nat) → Vector α n → Vector α m :=
|
||||
fun n m xs => Vector.ofList (f n m xs.toList) (by grind) -- apply h; apply Vector.length_toList) -- fails here: (by grind)
|
||||
|
||||
/--
|
||||
warning: found discrepancy between old and new `grind` pattern inference procedures, old:
|
||||
[@List.length #2 (@toList _ #1 #0)]
|
||||
new:
|
||||
[@toList #2 #1 #0]
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option backward.grind.checkInferPatternDiscrepancy true in
|
||||
@[grind] theorem Vector.length_toList' (xs : Vector α n) : xs.toList.length = n := by sorry
|
||||
Loading…
Add table
Reference in a new issue