From 16740a15405f51418d159fc5b76d5df5f5e6424c Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 27 Nov 2025 10:05:47 -0800 Subject: [PATCH] feat: some `grind_pattern` constraints (#11405) This PR implements the following `grind_pattern` constraints: ```lean grind_pattern fax => f x where depth x < 2 grind_pattern fax => f x where is_ground x grind_pattern fax => f x where size x < 5 grind_pattern fax => f x where gen < 2 grind_pattern fax => f x where max_insts < 4 grind_pattern gax => g as where as =?= _ :: _ ``` --- src/Lean/Elab/Tactic/Grind/Main.lean | 5 +- src/Lean/Meta/Tactic/Grind/EMatch.lean | 64 ++++++++-- src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean | 2 +- src/Lean/Meta/Tactic/Grind/Parser.lean | 4 +- src/Lean/Meta/Tactic/Grind/Types.lean | 3 + tests/lean/run/grind_pattern_cnstr_2.lean | 110 ++++++++++++++++++ 6 files changed, 173 insertions(+), 15 deletions(-) create mode 100644 tests/lean/run/grind_pattern_cnstr_2.lean diff --git a/src/Lean/Elab/Tactic/Grind/Main.lean b/src/Lean/Elab/Tactic/Grind/Main.lean index d4c0ed312b..e341cfcd69 100644 --- a/src/Lean/Elab/Tactic/Grind/Main.lean +++ b/src/Lean/Elab/Tactic/Grind/Main.lean @@ -100,8 +100,7 @@ where else if kind == ``defEq then elabDefEq xs cnstr[0] cnstr[2] else if kind == ``genLt then - let (_, lhs) ← findLHS xs cnstr[1] - return .genLt lhs cnstr[3].toNat + return .genLt cnstr[2].toNat else if kind == ``sizeLt then let (_, lhs) ← findLHS xs cnstr[1] return .sizeLt lhs cnstr[3].toNat @@ -109,7 +108,7 @@ where let (_, lhs) ← findLHS xs cnstr[1] return .depthLt lhs cnstr[3].toNat else if kind == ``maxInsts then - return .maxInsts cnstr[1].toNat + return .maxInsts cnstr[2].toNat else if kind == ``isValue then let (_, lhs) ← findLHS xs cnstr[1] return .isValue lhs false diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index fdb23a5e38..fc776a05d8 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -637,13 +637,17 @@ private abbrev withFreshNGen (x : M α) : M α := do finally setNGen ngen -/-- -Checks constraints of the form `lhs =/= rhs`. --/ -private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do +private def getLHS (args : Array Expr) (lhs : Nat) : MetaM Expr := do unless lhs < args.size do throwError "`grind` internal error, invalid variable in `grind_pattern` constraint" - let lhs := args[args.size - lhs - 1]! + instantiateMVars args[args.size - lhs - 1]! + +/-- +Checks constraints of the form `lhs =/= rhs` and `lhs =?= rhs`. +`expectedResult` is `true` if `lhs` and `rhs` should be definitionally equal. +-/ +private def checkDefEq (expectedResult : Bool) (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do + let lhs ← getLHS args lhs /- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/ let rhsExpr := rhs.expr.instantiateRev args let rhsExpr := rhsExpr.instantiateLevelParams levelParams us @@ -657,7 +661,38 @@ private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Ar let (_, _, rhsExpr) ← lambdaMetaTelescope rhsExpr (some rhs.numMVars) /- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/ let defEq ← isDefEqGuarded lhs rhsExpr - return !defEq + return defEq == expectedResult + +/-- +Helper function for checking grind pattern constraints of the form `size e < threshold` +Implicit arguments and type information in lambdas and let-expressions are ignored. +-/ +partial def checkSize (e : Expr) (threshold : Nat) : MetaM Bool := + return (← go e |>.run |>.run 0).1.isSome +where + go (e : Expr) : OptionT (StateT Nat MetaM) Unit := do + guard ((← get) < threshold) + modify (·+1) + match e with + | .forallE _ d b _ => go d; go b + | .lam _ _ b _ => go b + | .letE _ _ v b _ => go v; go b + | .mdata _ e + | .proj _ _ e => go e + | .app .. => e.withApp fun f args => do + if f.hasLooseBVars then + go f; args.forM go + else + let paramInfo := (← getFunInfo f).paramInfo + for h : i in *...args.size do + let arg := args[i] + if h : i < paramInfo.size then + let pinfo := paramInfo[i] + if pinfo.isExplicit && !pinfo.isProp then + go arg + else + go arg + | _ => return () /-- Checks whether `vars` satisfies the `grind_pattern` constraints attached at `thm`. @@ -672,14 +707,25 @@ In the example above, a `map_map` instance should be added to the logical contex Remark: `proof` is used to extract the universe parameters in the proof. -/ -private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM Bool := do +private def checkConstraints (thm : EMatchTheorem) (gen : Nat) (proof : Expr) (args : Array Expr) : GoalM Bool := do if thm.cnstrs.isEmpty then return true /- **Note**: Only top-level theorems have constraints. -/ let .const declName us := proof | return true let info ← getConstInfo declName thm.cnstrs.allM fun cnstr => do match cnstr with - | .notDefEq lhs rhs => checkNotDefEq info.levelParams us args lhs rhs + | .notDefEq lhs rhs => checkDefEq (expectedResult := false) info.levelParams us args lhs rhs + | .defEq lhs rhs => checkDefEq (expectedResult := true) info.levelParams us args lhs rhs + | .depthLt lhs n => return (← getLHS args lhs).approxDepth.toNat < n + | .isGround lhs => let lhs ← getLHS args lhs; return !lhs.hasFVar && !lhs.hasMVar + | .sizeLt lhs n => checkSize (← getLHS args lhs) n + | .genLt n => return gen < n + | .maxInsts n => + /- + **Note**: We are checking the number of instances produced in the whole proof. + It may be useful to bound the number of instances in the current branch. + -/ + return (← getEMatchTheoremNumInstances thm) + 1 < n | _ => throwError "NIY" /-- @@ -700,7 +746,7 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w return () let (some _, c) ← applyAssignment mvars |>.run c | return () let some _ ← synthesizeInsts mvars bis | return () - if (← checkConstraints thm proof mvars) then + if (← checkConstraints thm c.gen proof mvars) then let proof := mkAppN proof mvars if (← mvars.allM (·.mvarId!.isAssigned)) then addNewInstance thm proof c.gen diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 7365685e53..1047c54553 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -380,7 +380,7 @@ inductive EMatchTheoremConstraint where | /-- Instantiates the theorem only if its generation is less than `n` -/ - genLt (lhs : Nat) (n : Nat) + genLt (n : Nat) | /-- Constraints of the form `is_ground x`. Instantiates the theorem only if `x` is ground term. diff --git a/src/Lean/Meta/Tactic/Grind/Parser.lean b/src/Lean/Meta/Tactic/Grind/Parser.lean index 51e071cd82..f4e6bc0aa6 100644 --- a/src/Lean/Meta/Tactic/Grind/Parser.lean +++ b/src/Lean/Meta/Tactic/Grind/Parser.lean @@ -16,8 +16,8 @@ def isStrictValue := leading_parser nonReservedSymbol "is_strict_value " >> iden def isGround := leading_parser nonReservedSymbol "is_ground " >> ident >> optional ";" def sizeLt := leading_parser nonReservedSymbol "size " >> ident >> " < " >> numLit >> optional ";" def depthLt := leading_parser nonReservedSymbol "depth " >> ident >> " < " >> numLit >> optional ";" -def genLt := leading_parser nonReservedSymbol "gen " >> ident >> " < " >> numLit >> optional ";" -def maxInsts := leading_parser nonReservedSymbol "max_insts " >> numLit >> optional ";" +def genLt := leading_parser nonReservedSymbol "gen" >> " < " >> numLit >> optional ";" +def maxInsts := leading_parser nonReservedSymbol "max_insts" >> " < " >> numLit >> optional ";" def guard := leading_parser nonReservedSymbol "guard " >> checkColGe "irrelevant" >> termParser >> optional ";" def check := leading_parser nonReservedSymbol "check " >> checkColGe "irrelevant" >> termParser >> optional ";" def notDefEq := leading_parser atomic (ident >> " =/= ") >> checkColGe "irrelevant" >> termParser >> optional ";" diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 72fcc54a23..ed1f53ec3d 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -358,6 +358,9 @@ private def incCounter [Hashable α] [BEq α] (s : PHashMap α Nat) (k : α) : P private def saveEMatchTheorem (thm : EMatchTheorem) : GrindM Unit := do modify fun s => { s with counters.thm := incCounter s.counters.thm thm.origin } +def getEMatchTheoremNumInstances (thm : EMatchTheorem) : GrindM Nat := do + return (← get).counters.thm.find? thm.origin |>.getD 0 + def saveCases (declName : Name) : GrindM Unit := do modify fun s => { s with counters.case := incCounter s.counters.case declName } diff --git a/tests/lean/run/grind_pattern_cnstr_2.lean b/tests/lean/run/grind_pattern_cnstr_2.lean new file mode 100644 index 0000000000..71797ab497 --- /dev/null +++ b/tests/lean/run/grind_pattern_cnstr_2.lean @@ -0,0 +1,110 @@ + +namespace Ex1 +opaque f : Nat → Nat +axiom fax : f x ≥ f (f x) + +grind_pattern fax => f x where + depth x < 2 + +/-- +trace: [grind.ematch.instance] fax: f a ≥ f (f a) +[grind.ematch.instance] fax: f (f a) ≥ f (f (f a)) +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.ematch.instance true in +example (h : f a = 0) : False := by + grind +end Ex1 + +namespace Ex2 +opaque f : Nat → Nat +axiom fax : f x ≥ f (f x) + +grind_pattern fax => f x where + is_ground x + depth x < 3 + +opaque b : Nat + +-- Theorems containing `a` should not be instantiate since it is a local variable +/-- +trace: [grind.ematch.instance] fax: f b ≥ f (f b) +[grind.ematch.instance] fax: f (f b) ≥ f (f (f b)) +[grind.ematch.instance] fax: f (f (f b)) ≥ f (f (f (f b))) +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.ematch.instance true in +example : f a = 0 → f b = 0 → False := by + grind +end Ex2 + +namespace Ex3 +def f {α : Type} : α → α → α := fun x _ => x +axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x) + +grind_pattern fax => f x x where + size x < 5 + +/-- +trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a) +[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a)) +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.ematch.instance true in +example (a b : List (List Nat)) : f a a = b → False := by + grind +end Ex3 + +namespace Ex4 +def f {α : Type} : α → α → α := fun x _ => x +axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x) + +grind_pattern fax => f x x where + gen < 2 + +/-- +trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a) +[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a)) +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.ematch.instance true in +example (a b : List (List Nat)) : f a a = b → False := by + grind +end Ex4 + + +namespace Ex5 +opaque f : Nat → Nat +axiom fax : f x ≥ f (f x) + +grind_pattern fax => f x where + max_insts < 4 + +/-- +trace: [grind.ematch.instance] fax: f c ≥ f (f c) +[grind.ematch.instance] fax: f b ≥ f (f b) +[grind.ematch.instance] fax: f a ≥ f (f a) +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.ematch.instance true in +example : f a = 0 → f b = 0 → f c = 0 → False := by + grind + +end Ex5 + +namespace Ex6 + +opaque g : List Nat → Nat +opaque f : List Nat → List Nat +axiom gax (as : List Nat) : g as > g (f as) + +grind_pattern gax => g as where + as =?= _ :: _ + +/-- trace: [grind.ematch.instance] gax: g [1, 2, 3] > g (f [1, 2, 3]) -/ +#guard_msgs (drop error, trace) in +set_option trace.grind.ematch.instance true in +example (h : g [1, 2, 3] > 0) : False := by + grind + +end Ex6