diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index 71b9455a93..ea7e0b657c 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -6,6 +6,18 @@ Authors: Leonardo de Moura prelude import Init.Tactics +namespace Lean.Parser.Attr + +syntax grindEq := "=" +syntax grindEqBoth := "_=_" +syntax grindEqRhs := "=_" +syntax grindBwd := "←" +syntax grindFwd := "→" + +syntax (name := grind) "grind" (grindEq <|> grindBwd <|> grindFwd <|> grindEqBoth <|> grindEqRhs)? : attr + +end Lean.Parser.Attr + namespace Lean.Grind /-- The configuration for `grind`. diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 01ee792bbe..022695fb35 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -34,6 +34,7 @@ builtin_initialize registerTraceClass `grind.eqc builtin_initialize registerTraceClass `grind.internalize builtin_initialize registerTraceClass `grind.ematch builtin_initialize registerTraceClass `grind.ematch.pattern +builtin_initialize registerTraceClass `grind.ematch.pattern.search builtin_initialize registerTraceClass `grind.ematch.instance builtin_initialize registerTraceClass `grind.ematch.instance.assignment builtin_initialize registerTraceClass `grind.issues diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 99e8a0a518..9605c8e116 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ prelude import Init.Grind.Util +import Init.Grind.Tactics import Lean.HeadIndex import Lean.PrettyPrinter import Lean.Util.FoldConsts @@ -218,16 +219,18 @@ private def getPatternFn? (pattern : Expr) : Option Expr := /-- Returns a bit-mask `mask` s.t. `mask[i]` is true if the the corresponding argument is -- a type or type former, or +- a type (that is not a proposition) or type former, or - a proof, or - an instance implicit argument When `mask[i]`, we say the corresponding argument is a "support" argument. -/ -private def getPatternFunMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do +def getPatternSupportMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do forallBoundedTelescope (← inferType f) numArgs fun xs _ => do xs.mapM fun x => do - if (← isTypeFormer x <||> isProof x) then + if (← isProp x) then + return false + else if (← isTypeFormer x <||> isProof x) then return true else return (← x.fvarId!.getDecl).binderInfo matches .instImplicit @@ -246,7 +249,7 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do assert! f.isConst || f.isFVar saveSymbol f.toHeadIndex let mut args := pattern.getAppArgs - let supportMask ← getPatternFunMask f args.size + let supportMask ← getPatternSupportMask f args.size for i in [:args.size] do let arg := args[i]! let isSupport := supportMask[i]?.getD false @@ -278,6 +281,9 @@ def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.Hash let (patterns, s) ← patterns.mapM go |>.run {} return (patterns, s.symbols.toList, s.bvarsFound) +def normalizePattern (e : Expr) : M Expr := do + go e + end NormalizePattern /-- @@ -402,26 +408,50 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) : msg := msg ++ m!"{x} : {← inferType x}" addMessageContextFull msg +/-- +Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns. +Pattern variables are represented using de Bruijn indices. +-/ +def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do + let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns + trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}" + if let .missing pos ← checkCoverage proof numParams bvarFound then + let pats : MessageData := m!"{patterns.map ppPattern}" + throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}" + return { + proof, patterns, numParams, symbols + levelParams, origin + } + +private def getProofFor (declName : Name) : CoreM Expr := do + let .thmInfo info ← getConstInfo declName + | throwError "`{declName}` is not a theorem" + let us := info.levelParams.map mkLevelParam + return mkConst declName us + /-- Creates an E-matching theorem for `declName` with `numParams` parameters, and the given set of patterns. Pattern variables are represented using de Bruijn indices. -/ def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM EMatchTheorem := do - let .thmInfo info ← getConstInfo declName - | throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic" - let us := info.levelParams.map mkLevelParam - let proof := mkConst declName us - let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns - assert! symbols.all fun s => s matches .const _ - trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}" - if let .missing pos ← checkCoverage proof numParams bvarFound then - let pats : MessageData := m!"{patterns.map ppPattern}" - throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}" - return { - proof, patterns, numParams, symbols - levelParams := #[] - origin := .decl declName - } + mkEMatchTheoremCore (.decl declName) #[] numParams (← getProofFor declName) patterns + +/-- +Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs = rhs`, +creates an E-matching pattern for it using `addEMatchTheorem n [lhs]` +If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern. +-/ +def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) : MetaM EMatchTheorem := do + let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do + let (lhs, rhs) ← match_expr type with + | Eq _ lhs rhs => pure (lhs, rhs) + | Iff lhs rhs => pure (lhs, rhs) + | HEq _ lhs _ rhs => pure (lhs, rhs) + | _ => throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}" + let pat := if useLhs then lhs else rhs + let pat ← preprocessPattern pat normalizePattern + return (xs.size, [pat.abstract xs]) + mkEMatchTheoremCore origin levelParams numParams proof patterns /-- Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`, @@ -430,17 +460,8 @@ creates an E-matching pattern for it using `addEMatchTheorem n [lhs]` If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern. -/ -def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) : MetaM EMatchTheorem := do - let info ← getConstInfo declName - let (numParams, patterns) ← forallTelescopeReducing info.type fun xs type => do - let lhs ← match_expr type with - | Eq _ lhs _ => pure lhs - | Iff lhs _ => pure lhs - | HEq _ lhs _ _ => pure lhs - | _ => throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}" - let lhs ← preprocessPattern lhs normalizePattern - return (xs.size, [lhs.abstract xs]) - mkEMatchTheorem declName numParams patterns +def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) : MetaM EMatchTheorem := do + mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs /-- Adds an E-matching theorem to the environment. @@ -460,18 +481,177 @@ def addEMatchEqTheorem (declName : Name) : MetaM Unit := do def getEMatchTheorems : CoreM EMatchTheorems := return ematchTheoremsExt.getState (← getEnv) -private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do +private inductive TheoremKind where + | eqLhs | eqRhs | eqBoth | fwd | bwd | default + deriving Inhabited, BEq + +private def TheoremKind.toAttribute : TheoremKind → String + | .eqLhs => "[grind =]" + | .eqRhs => "[grind =_]" + | .eqBoth => "[grind _=_]" + | .fwd => "[grind →]" + | .bwd => "[grind ←]" + | .default => "[grind]" + +private def TheoremKind.explainFailure : TheoremKind → String + | .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion" + | .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion" + | .eqBoth => unreachable! -- eqBoth is a macro + | .fwd => "failed to find patterns in the antecedents of the theorem" + | .bwd => "failed to find patterns in the theorem's conclusion" + | .default => "failed to find patterns" + +/-- Returns the types of `xs` that are propositions. -/ +private def getPropTypes (xs : Array Expr) : MetaM (Array Expr) := + xs.filterMapM fun x => do + let type ← inferType x + if (← isProp type) then return some type else return none + +/-- State for the (pattern) `CollectorM` monad -/ +private structure Collector.State where + /-- Pattern found so far. -/ + patterns : Array Expr := #[] + done : Bool := false + +private structure Collector.Context where + proof : Expr + xs : Array Expr + +/-- Monad for collecting patterns for a theorem. -/ +private abbrev CollectorM := ReaderT Collector.Context $ StateRefT Collector.State NormalizePattern.M + +/-- Similar to `getPatternFn?`, but operates on expressions that do not contain loose de Bruijn variables. -/ +private def isPatternFnCandidate (f : Expr) : CollectorM Bool := do + match f with + | .const declName _ => return !isForbidden declName + | .fvar .. => return !(← read).xs.contains f + | _ => return false + +private def addNewPattern (p : Expr) : CollectorM Unit := do + trace[grind.ematch.pattern.search] "found pattern: {ppPattern p}" + let bvarsFound := (← getThe NormalizePattern.State).bvarsFound + let done := (← checkCoverage (← read).proof (← read).xs.size bvarsFound) matches .ok + if done then + trace[grind.ematch.pattern.search] "found full coverage" + modify fun s => { s with patterns := s.patterns.push p, done } + +private partial def collect (e : Expr) : CollectorM Unit := do + if (← get).done then return () + match e with + | .app .. => + let f := e.getAppFn + if (← isPatternFnCandidate f) then + let saved ← getThe NormalizePattern.State + try + trace[grind.ematch.pattern.search] "candidate: {e}" + let p := e.abstract (← read).xs + unless p.hasLooseBVars do + trace[grind.ematch.pattern.search] "skip, does not contain pattern variables" + return () + let p ← NormalizePattern.normalizePattern p + if saved.bvarsFound.size < (← getThe NormalizePattern.State).bvarsFound.size then + addNewPattern p + return () + trace[grind.ematch.pattern.search] "skip, no new variables covered" + -- restore state and continue search + set saved + catch _ => + -- restore state and continue search + trace[grind.ematch.pattern.search] "skip, exception during normalization" + set saved + let args := e.getAppArgs + for arg in args, flag in (← NormalizePattern.getPatternSupportMask f args.size) do + unless flag do + collect arg + | .forallE _ d b _ => + if (← pure e.isArrow <&&> isProp d <&&> isProp b) then + collect d + collect b + | _ => return () + +private def collectPatterns? (proof : Expr) (xs : Array Expr) (searchPlaces : Array Expr) : MetaM (Option (List Expr × List HeadIndex)) := do + let go : CollectorM (Option (List Expr)) := do + for place in searchPlaces do + let place ← preprocessPattern place + collect place + if (← get).done then + return some ((← get).patterns.toList) + return none + let (some ps, s) ← go { proof, xs } |>.run' {} |>.run {} + | return none + return some (ps, s.symbols.toList) + +private def mkEMatchTheoremWithKind? (origin : Origin) (levelParams : Array Name) (proof : Expr) (kind : TheoremKind) : MetaM (Option EMatchTheorem) := do + if kind == .eqLhs then + return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := false) (useLhs := true)) + else if kind == .eqRhs then + return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := false) (useLhs := false)) + let type ← inferType proof + forallTelescopeReducing type fun xs type => do + let searchPlaces ← match kind with + | .fwd => + let ps ← getPropTypes xs + if ps.isEmpty then + throwError "invalid `grind` forward theorem, theorem `{← origin.pp}` does not have proposional hypotheses" + pure ps + | .bwd => pure #[type] + | .default => pure <| #[type] ++ (← getPropTypes xs) + | _ => unreachable! + go xs searchPlaces +where + go (xs : Array Expr) (searchPlaces : Array Expr) : MetaM (Option EMatchTheorem) := do + let some (patterns, symbols) ← collectPatterns? proof xs searchPlaces + | return none + let numParams := xs.size + trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}" + return some { + proof, patterns, numParams, symbols + levelParams, origin + } + +private def getKind (stx : Syntax) : TheoremKind := + if stx[1].isNone then + .default + else if stx[1][0].getKind == ``Parser.Attr.grindEq then + .eqLhs + else if stx[1][0].getKind == ``Parser.Attr.grindFwd then + .fwd + else if stx[1][0].getKind == ``Parser.Attr.grindEqRhs then + .eqRhs + else if stx[1][0].getKind == ``Parser.Attr.grindEqBoth then + .eqBoth + else + .bwd + +private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (useLhs := true) : MetaM Unit := do if (← getConstInfo declName).isTheorem then - ematchTheoremsExt.add (← mkEMatchEqTheorem declName) attrKind + ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs)) attrKind else if let some eqns ← getEqnsFor? declName then + unless useLhs do + throwError "`{declName}` is a definition, you must only use the left-hand side for extracting patterns" for eqn in eqns do ematchTheoremsExt.add (← mkEMatchEqTheorem eqn) attrKind else throwError "`[grind_eq]` attribute can only be applied to equational theorems or function definitions" +private def addGrindAttr (declName : Name) (attrKind : AttributeKind) (thmKind : TheoremKind) : MetaM Unit := do + if thmKind == .eqLhs then + addGrindEqAttr declName attrKind (useLhs := true) + else if thmKind == .eqRhs then + addGrindEqAttr declName attrKind (useLhs := false) + else if thmKind == .eqBoth then + addGrindEqAttr declName attrKind (useLhs := true) + addGrindEqAttr declName attrKind (useLhs := false) + else if !(← getConstInfo declName).isTheorem then + addGrindEqAttr declName attrKind + else + let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind + | throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command" + ematchTheoremsExt.add thm attrKind + builtin_initialize registerBuiltinAttribute { - name := `grind_eq + name := `grind descr := "The `[grind_eq]` attribute is used to annotate equational theorems and functions.\ When applied to an equational theorem, it marks the theorem for use in heuristic instantiations by the `grind` tactic.\ @@ -480,8 +660,8 @@ builtin_initialize For example, if a theorem `@[grind_eq] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\ `grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`." applicationTime := .afterCompilation - add := fun declName _ attrKind => - addGrindEqAttr declName attrKind |>.run' {} + add := fun declName stx attrKind => do + addGrindAttr declName attrKind (getKind stx) |>.run' {} } end Lean.Meta.Grind diff --git a/tests/lean/run/grind_ematch1.lean b/tests/lean/run/grind_ematch1.lean index 5bce28cc44..9523882c48 100644 --- a/tests/lean/run/grind_ematch1.lean +++ b/tests/lean/run/grind_ematch1.lean @@ -69,3 +69,148 @@ info: [grind.ematch.instance] Rtrans: R a d → R d e → R a e #guard_msgs (info) in example : R a b → R b c → R c d → R d e → R a d := by grind + + +namespace using_grind_fwd + +opaque S : Nat → Nat → Prop + +/-- +error: `@[grind →] theorem using_grind_fwd.StransBad` failed to find patterns in the antecedents of the theorem, consider using different options or the `grind_pattern` command +-/ +#guard_msgs (error) in +@[grind→] theorem StransBad (a b c d : Nat) : S a b ∨ R a b → S b c → S a c ∧ S b d := sorry + + +set_option trace.grind.ematch.pattern.search true in +/-- +info: [grind.ematch.pattern.search] candidate: S a b +[grind.ematch.pattern.search] found pattern: S #4 #3 +[grind.ematch.pattern.search] candidate: R a b +[grind.ematch.pattern.search] skip, no new variables covered +[grind.ematch.pattern.search] candidate: S b c +[grind.ematch.pattern.search] found pattern: S #3 #2 +[grind.ematch.pattern.search] found full coverage +[grind.ematch.pattern] Strans: [S #4 #3, S #3 #2] +-/ +#guard_msgs (info) in +@[grind→] theorem Strans (a b c : Nat) : S a b ∨ R a b → S b c → S a c := sorry + +/-- +info: [grind.ematch.instance] Strans: S a b ∨ R a b → S b c → S a c +-/ +#guard_msgs (info) in +example : S a b → S b c → S a c := by + grind + +end using_grind_fwd + +namespace using_grind_bwd + +opaque P : Nat → Prop +opaque Q : Nat → Prop +opaque f : Nat → Nat → Nat + +/-- +info: [grind.ematch.pattern] pqf: [P (f #2 #1)] +-/ +#guard_msgs (info) in +@[grind←] theorem pqf : Q x → P (f x y) := sorry + +/-- +info: [grind.ematch.instance] pqf: Q a → P (f a b) +-/ +#guard_msgs (info) in +example : Q 0 → Q 1 → Q 2 → Q 3 → ¬ P (f a b) → a = 1 → False := by + grind + +end using_grind_bwd + +namespace using_grind_fwd2 + +opaque P : Nat → Prop +opaque Q : Nat → Prop +opaque f : Nat → Nat → Nat + +/-- +error: `@[grind →] theorem using_grind_fwd2.pqfBad` failed to find patterns in the antecedents of the theorem, consider using different options or the `grind_pattern` command +-/ +#guard_msgs (error) in +@[grind→] theorem pqfBad : Q x → P (f x y) := sorry + +/-- +info: [grind.ematch.pattern] pqf: [Q #1] +-/ +#guard_msgs (info) in +@[grind→] theorem pqf : Q x → P (f x x) := sorry + +/-- +info: [grind.ematch.instance] pqf: Q 3 → P (f 3 3) +[grind.ematch.instance] pqf: Q 2 → P (f 2 2) +[grind.ematch.instance] pqf: Q 1 → P (f 1 1) +[grind.ematch.instance] pqf: Q 0 → P (f 0 0) +-/ +#guard_msgs (info) in +example : Q 0 → Q 1 → Q 2 → Q 3 → ¬ P (f a a) → a = 1 → False := by + grind + +end using_grind_fwd2 + +namespace using_grind_mixed + +opaque P : Nat → Nat → Prop +opaque Q : Nat → Nat → Prop + +/-- +error: `@[grind →] theorem using_grind_mixed.pqBad1` failed to find patterns in the antecedents of the theorem, consider using different options or the `grind_pattern` command +-/ +#guard_msgs (error) in +@[grind→] theorem pqBad1 : P x y → Q x z := sorry + +/-- +error: `@[grind ←] theorem using_grind_mixed.pqBad2` failed to find patterns in the theorem's conclusion, consider using different options or the `grind_pattern` command +-/ +#guard_msgs (error) in +@[grind←] theorem pqBad2 : P x y → Q x z := sorry + + +/-- +info: [grind.ematch.pattern] pqBad: [Q #3 #1, P #3 #2] +-/ +#guard_msgs (info) in +@[grind] theorem pqBad : P x y → Q x z := sorry + +example : P a b → Q a c := by + grind + +end using_grind_mixed + + +namespace using_grind_rhs + +opaque f : Nat → Nat +opaque g : Nat → Nat → Nat + +/-- +info: [grind.ematch.pattern] fq: [g #0 (f #0)] +-/ +#guard_msgs (info) in +@[grind =_] +theorem fq : f x = g x (f x) := sorry + +end using_grind_rhs + +namespace using_grind_lhs_rhs + +opaque f : Nat → Nat +opaque g : Nat → Nat → Nat + +/-- +info: [grind.ematch.pattern] fq: [f #0] +[grind.ematch.pattern] fq: [g #0 (g #0 #0)] +-/ +#guard_msgs (info) in +@[grind _=_] +theorem fq : f x = g x (g x x) := sorry + +end using_grind_lhs_rhs diff --git a/tests/lean/run/grind_eq.lean b/tests/lean/run/grind_eq.lean index b8d7c22d62..333aab7593 100644 --- a/tests/lean/run/grind_eq.lean +++ b/tests/lean/run/grind_eq.lean @@ -1,6 +1,8 @@ opaque g : Nat → Nat -@[grind_eq] def f (a : Nat) := +set_option trace.Meta.debug true + +@[grind] def f (a : Nat) := match a with | 0 => 10 | x+1 => g (f x) @@ -21,7 +23,7 @@ info: [grind.assert] f (y + 1) = a example : f (y + 1) = a → a = g (f y):= by grind -@[grind_eq] def app (xs ys : List α) := +@[grind] def app (xs ys : List α) := match xs with | [] => ys | x::xs => x :: app xs ys @@ -43,7 +45,7 @@ example : app [1, 2] ys = xs → xs = 1::2::ys := by opaque p : Nat → Nat → Prop opaque q : Nat → Prop -@[grind_eq] theorem pq : p x x ↔ q x := by sorry +@[grind =] theorem pq : p x x ↔ q x := by sorry /-- info: [grind.assert] p a a @@ -58,7 +60,7 @@ example : p a a → q a := by opaque appV (xs : Vector α n) (ys : Vector α m) : Vector α (n + m) := Vector.append xs ys -@[grind_eq] +@[grind =] theorem appV_assoc (a : Vector α n) (b : Vector α m) (c : Vector α n') : HEq (appV a (appV b c)) (appV (appV a b) c) := sorry diff --git a/tests/lean/run/grind_pattern1.lean b/tests/lean/run/grind_pattern1.lean index f41e3bb96f..d827e1efa0 100644 --- a/tests/lean/run/grind_pattern1.lean +++ b/tests/lean/run/grind_pattern1.lean @@ -20,7 +20,7 @@ grind_pattern List.mem_concat_self => a ∈ xs ++ [a] def foo (x : Nat) := x + x /-- -error: `foo` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic +error: `foo` is not a theorem -/ #guard_msgs in grind_pattern foo => x + x