diff --git a/src/Init/Grind/Norm.lean b/src/Init/Grind/Norm.lean index 2673caa837..d9ba8051ed 100644 --- a/src/Init/Grind/Norm.lean +++ b/src/Init/Grind/Norm.lean @@ -46,6 +46,12 @@ attribute [grind_norm] not_false_eq_true theorem imp_eq (p q : Prop) : (p → q) = (¬ p ∨ q) := by by_cases p <;> by_cases q <;> simp [*] +@[grind_norm] theorem true_imp_eq (p : Prop) : (True → p) = p := by simp +@[grind_norm] theorem false_imp_eq (p : Prop) : (False → p) = True := by simp +@[grind_norm] theorem imp_true_eq (p : Prop) : (p → True) = True := by simp +@[grind_norm] theorem imp_false_eq (p : Prop) : (p → False) = ¬p := by simp +@[grind_norm] theorem imp_self_eq (p : Prop) : (p → p) = True := by simp + -- And @[grind_norm↓] theorem not_and (p q : Prop) : (¬(p ∧ q)) = (¬p ∨ ¬q) := by by_cases p <;> by_cases q <;> simp [*] diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 3b427eca41..acf370f11f 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -170,11 +170,43 @@ private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTh initial := {} } +/-- +Symbols with built-in support in `grind` are unsuitable as pattern candidates for E-matching. +This is because `grind` performs normalization operations and uses specialized data structures +to implement these symbols, which may interfere with E-matching behavior. +-/ -- TODO: create attribute? private def forbiddenDeclNames := #[``Eq, ``HEq, ``Iff, ``And, ``Or, ``Not] private def isForbidden (declName : Name) := forbiddenDeclNames.contains declName +/-- +Auxiliary function to expand a pattern containing forbidden application symbols +into a multi-pattern. + +This function enhances the usability of the `[grind =]` attribute by automatically handling +forbidden pattern symbols. For example, consider the following theorem tagged with this attribute: +``` +getLast?_eq_some_iff {xs : List α} {a : α} : xs.getLast? = some a ↔ ∃ ys, xs = ys ++ [a] +``` +Here, the selected pattern is `xs.getLast? = some a`, but `Eq` is a forbidden pattern symbol. +Instead of producing an error, this function converts the pattern into a multi-pattern, +allowing the attribute to be used conveniently. + +The function recursively expands patterns with forbidden symbols by splitting them +into their sub-components. If the pattern does not contain forbidden symbols, +it is returned as-is. +-/ +partial def splitWhileForbidden (pat : Expr) : List Expr := + match_expr pat with + | Not p => splitWhileForbidden p + | And p₁ p₂ => splitWhileForbidden p₁ ++ splitWhileForbidden p₂ + | Or p₁ p₂ => splitWhileForbidden p₁ ++ splitWhileForbidden p₂ + | Eq _ lhs rhs => splitWhileForbidden lhs ++ splitWhileForbidden rhs + | Iff lhs rhs => splitWhileForbidden lhs ++ splitWhileForbidden rhs + | HEq _ lhs _ rhs => splitWhileForbidden lhs ++ splitWhileForbidden rhs + | _ => [pat] + private def dontCare := mkConst (Name.mkSimple "[grind_dontcare]") def mkGroundPattern (e : Expr) : Expr := @@ -468,7 +500,8 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : | _ => throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}" let pat := if useLhs then lhs else rhs let pat ← preprocessPattern pat normalizePattern - return (xs.size, [pat.abstract xs]) + let pats := splitWhileForbidden (pat.abstract xs) + return (xs.size, pats) mkEMatchTheoremCore origin levelParams numParams proof patterns /-- diff --git a/tests/lean/run/grind_eq_pattern.lean b/tests/lean/run/grind_eq_pattern.lean new file mode 100644 index 0000000000..179a6d6496 --- /dev/null +++ b/tests/lean/run/grind_eq_pattern.lean @@ -0,0 +1,22 @@ +attribute [grind] List.append_ne_nil_of_left_ne_nil +attribute [grind] List.append_ne_nil_of_right_ne_nil +/-- +info: [grind.ematch.pattern] List.getLast?_eq_some_iff: [@List.getLast? #2 #1, @some ? #0] +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.pattern true in +attribute [grind =] List.getLast?_eq_some_iff + +/-- +info: [grind.assert] xs.getLast? = b? +[grind.assert] b? = some 10 +[grind.assert] xs = [] +[grind.assert] (xs.getLast? = some 10) = ∃ ys, xs = ys ++ [10] +[grind.assert] xs = w ++ [10] +[grind.assert] ¬w = [] → ¬w ++ [10] = [] +[grind.assert] ¬w ++ [10] = [] +-/ +#guard_msgs (info) in +set_option trace.grind.assert true in +example (xs : List Nat) : xs.getLast? = b? → b? = some 10 → xs ≠ [] := by + grind