diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 6773e8eb31..0c3e7e0ede 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -1009,6 +1009,7 @@ where | .bvar idx => modify fun s => if s.contains idx then s else idx :: s | _ => return () +namespace OldCollector private def diff (s : List Nat) (found : Std.HashSet Nat) : List Nat := if found.isEmpty then s else s.filter fun x => !found.contains x @@ -1069,19 +1070,129 @@ private partial def collect (e : Expr) : CollectorM Unit := do collect b | _ => return () +end OldCollector + +private def sizeOfDiff (s₁ s₂ : Std.HashSet Nat) : Nat := + s₂.fold (init := s₁.size) fun num idx => + if s₁.contains idx then num - 1 else num + +/-- +Normalizes `e` if it qualifies as a candidate pattern, and returns +`some p` where `p` is the normalized pattern. + +`argKinds == NormalizePattern.getPatternArgKinds e.getAppFn e.getAppNumArgs` +-/ +private def normalizePattern? (e : Expr) (argKinds : Array NormalizePattern.PatternArgKind) : CollectorM (Option Expr) := do + let p := e.abstract (← read).xs + unless p.hasLooseBVars do + trace[grind.debug.ematch.pattern] "skip, does not contain pattern variables" + return none + -- Normalization state before normalizing `e` + let stateBefore ← getThe NormalizePattern.State + let failed : CollectorM (Option Expr) := do + set stateBefore + return none + -- Returns the number of new variables with respect to `saved` + let getNumNewBVars : NormalizePattern.M Nat := do + return sizeOfDiff (← get).bvarsFound stateBefore.bvarsFound + try + let p ← NormalizePattern.normalizePattern p + let stateAfter ← getThe NormalizePattern.State + let numNewBVars ← getNumNewBVars + if numNewBVars == 0 then + trace[grind.debug.ematch.pattern] "skip, no new variables covered" + return (← failed) + /- + Checks whether one of `e`s children subsumes it. We say a child `c` subsumes `e` + 1- `e` and `c` have the same new pattern variables. We say a pattern variable is new if it is not in `stateOld.bvarsFound`. + 2- `c` is not a support argument. See `NormalizePattern.getPatternSupportMask` for definition. + 3- `c` is not an offset pattern. + 4- `c` is not a bound variable. + 5- `c` is also a candidate. + -/ + for arg in e.getAppArgs, argKind in argKinds do + unless argKind.isSupport do + unless arg.isFVar do + unless isOffsetPattern? arg |>.isSome do + if (← isPatternFnCandidate arg.getAppFn) then + let pArg := arg.abstract (← read).xs + set stateBefore + discard <| NormalizePattern.normalizePattern pArg + let numArgNewBVars ← getNumNewBVars + if numArgNewBVars == numNewBVars then + trace[grind.debug.ematch.pattern] "skip, subsumed by argument" + return (← failed) + set stateAfter + return some p + catch ex => + trace[grind.debug.ematch.pattern] "skip, exception during normalization{indentD ex.toMessageData}" + failed + +private partial def collect (e : Expr) : CollectorM Unit := do + if (← get).done then return () + match e with + | .app .. => + trace[grind.debug.ematch.pattern] "collect: {e}" + let f := e.getAppFn + let argKinds ← NormalizePattern.getPatternArgKinds f e.getAppNumArgs + if (← isPatternFnCandidate f) then + trace[grind.debug.ematch.pattern] "candidate: {e}" + if let some p ← normalizePattern? e argKinds then + addNewPattern p + return () + let args := e.getAppArgs + for arg in args, argKind in argKinds do + unless isOffsetPattern? arg |>.isSome do + trace[grind.debug.ematch.pattern] "arg: {arg}, support: {argKind.isSupport}" + unless argKind.isSupport do + collect arg + | .forallE _ d b _ => + if (← pure e.isArrow <&&> isProp d <&&> isProp b) then + collect d + collect b + | _ => return () + +register_builtin_option backward.grind.inferPattern : Bool := { + defValue := true + group := "backward compatibility" + descr := "use old E-matching pattern inference" +} + +register_builtin_option backward.grind.checkInferPatternDiscrepancy : Bool := { + defValue := false + group := "backward compatibility" + descr := "check whether old and new pattern inference procedures infer the same pattern" +} + private def collectPatterns? (proof : Expr) (xs : Array Expr) (searchPlaces : Array Expr) (symPrios : SymbolPriorities) (minPrio : Nat) : MetaM (Option (List Expr × List HeadIndex)) := do - let go : CollectorM (Option (List Expr)) := do + let go (useOld : Bool): CollectorM (Option (List Expr)) := do for place in searchPlaces do trace[grind.debug.ematch.pattern] "place: {place}" let place ← preprocessPattern place - collect place + if useOld then + OldCollector.collect place + else + collect place if (← get).done then return some ((← get).patterns.toList) return none - let (some ps, s) ← go { proof, xs } |>.run' {} { symPrios, minPrio } |>.run {} - | return none - return some (ps, s.symbols.toList) + let collect? (useOld : Bool) : MetaM (Option (List Expr × List HeadIndex)) := do + let (some ps, s) ← go useOld { proof, xs } |>.run' {} { symPrios, minPrio } |>.run {} + | return none + return some (ps, s.symbols.toList) + let useOld := backward.grind.inferPattern.get (← getOptions) + if backward.grind.checkInferPatternDiscrepancy.get (← getOptions) then + let oldResult? ← collect? (useOld := true) + let newResult? ← collect? (useOld := false) + let toPattern (result? : Option (List Expr × List HeadIndex)) : List MessageData := + let pat := result?.map (·.1) |>.getD [] + pat.map ppPattern + if oldResult? != newResult? then + logWarning m!"found discrepancy between old and new `grind` pattern inference procedures, old:{indentD (toPattern oldResult?)}\nnew:{indentD (toPattern newResult?)}" + return if useOld then oldResult? else newResult? + else + collect? useOld /-- Tries to find a ground pattern to activate the theorem. diff --git a/stage0/src/stdlib_flags.h b/stage0/src/stdlib_flags.h index fc33b085a4..d3bd0d23fb 100644 --- a/stage0/src/stdlib_flags.h +++ b/stage0/src/stdlib_flags.h @@ -1,5 +1,5 @@ +// update me! #include "util/options.h" -// Please update stage0 namespace lean { options get_default_options() { options opts; diff --git a/tests/lean/run/grind_pattern_inference_issue.lean b/tests/lean/run/grind_pattern_inference_issue.lean new file mode 100644 index 0000000000..4011c7e227 --- /dev/null +++ b/tests/lean/run/grind_pattern_inference_issue.lean @@ -0,0 +1,35 @@ +set_option warn.sorry false +namespace Test +set_option backward.grind.inferPattern false -- Force new pattern inference procedure + +inductive Vector (α : Type) : Nat → Type where + | nil : Vector α 0 + | cons (x : α) (xs : Vector α n) : Vector α (n + 1) + +def Vector.ofList (xs : List α) (h : xs.length = n) : Vector α n := + match n, xs with + | 0, [] => .nil + | (n + 1), x :: xs => .cons x (.ofList xs (by grind)) + +def Vector.toList (xs : Vector α n) : List α := + match xs with + | .nil => [] + | .cons x xs => x :: xs.toList + +/-- info: length_toList: [@toList #2 #1 #0] -/ +#guard_msgs (info) in +@[grind?] theorem Vector.length_toList (xs : Vector α n) : xs.toList.length = n := by sorry + +def wrapper (f : Nat → Nat → List α → List α) (h : ∀ n m xs, xs.length = n → (f n m xs).length = m) : + (n m : Nat) → Vector α n → Vector α m := + fun n m xs => Vector.ofList (f n m xs.toList) (by grind) -- apply h; apply Vector.length_toList) -- fails here: (by grind) + +/-- +warning: found discrepancy between old and new `grind` pattern inference procedures, old: + [@List.length #2 (@toList _ #1 #0)] +new: + [@toList #2 #1 #0] +-/ +#guard_msgs in +set_option backward.grind.checkInferPatternDiscrepancy true in +@[grind] theorem Vector.length_toList' (xs : Vector α n) : xs.toList.length = n := by sorry