feat: some grind_pattern constraints (#11405)

This PR implements the following `grind_pattern` constraints:
```lean
grind_pattern fax => f x  where
  depth x < 2

grind_pattern fax => f x where
  is_ground x

grind_pattern fax => f x where
  size x < 5

grind_pattern fax => f x where
  gen < 2

grind_pattern fax => f x where
  max_insts < 4

grind_pattern gax => g as where
  as =?= _ :: _
```
This commit is contained in:
Leonardo de Moura 2025-11-27 10:05:47 -08:00 committed by GitHub
parent 799d594400
commit 16740a1540
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 173 additions and 15 deletions

View file

@ -100,8 +100,7 @@ where
else if kind == ``defEq then
elabDefEq xs cnstr[0] cnstr[2]
else if kind == ``genLt then
let (_, lhs) ← findLHS xs cnstr[1]
return .genLt lhs cnstr[3].toNat
return .genLt cnstr[2].toNat
else if kind == ``sizeLt then
let (_, lhs) ← findLHS xs cnstr[1]
return .sizeLt lhs cnstr[3].toNat
@ -109,7 +108,7 @@ where
let (_, lhs) ← findLHS xs cnstr[1]
return .depthLt lhs cnstr[3].toNat
else if kind == ``maxInsts then
return .maxInsts cnstr[1].toNat
return .maxInsts cnstr[2].toNat
else if kind == ``isValue then
let (_, lhs) ← findLHS xs cnstr[1]
return .isValue lhs false

View file

@ -637,13 +637,17 @@ private abbrev withFreshNGen (x : M α) : M α := do
finally
setNGen ngen
/--
Checks constraints of the form `lhs =/= rhs`.
-/
private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
private def getLHS (args : Array Expr) (lhs : Nat) : MetaM Expr := do
unless lhs < args.size do
throwError "`grind` internal error, invalid variable in `grind_pattern` constraint"
let lhs := args[args.size - lhs - 1]!
instantiateMVars args[args.size - lhs - 1]!
/--
Checks constraints of the form `lhs =/= rhs` and `lhs =?= rhs`.
`expectedResult` is `true` if `lhs` and `rhs` should be definitionally equal.
-/
private def checkDefEq (expectedResult : Bool) (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
let lhs ← getLHS args lhs
/- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/
let rhsExpr := rhs.expr.instantiateRev args
let rhsExpr := rhsExpr.instantiateLevelParams levelParams us
@ -657,7 +661,38 @@ private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Ar
let (_, _, rhsExpr) ← lambdaMetaTelescope rhsExpr (some rhs.numMVars)
/- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/
let defEq ← isDefEqGuarded lhs rhsExpr
return !defEq
return defEq == expectedResult
/--
Helper function for checking grind pattern constraints of the form `size e < threshold`
Implicit arguments and type information in lambdas and let-expressions are ignored.
-/
partial def checkSize (e : Expr) (threshold : Nat) : MetaM Bool :=
return (← go e |>.run |>.run 0).1.isSome
where
go (e : Expr) : OptionT (StateT Nat MetaM) Unit := do
guard ((← get) < threshold)
modify (·+1)
match e with
| .forallE _ d b _ => go d; go b
| .lam _ _ b _ => go b
| .letE _ _ v b _ => go v; go b
| .mdata _ e
| .proj _ _ e => go e
| .app .. => e.withApp fun f args => do
if f.hasLooseBVars then
go f; args.forM go
else
let paramInfo := (← getFunInfo f).paramInfo
for h : i in *...args.size do
let arg := args[i]
if h : i < paramInfo.size then
let pinfo := paramInfo[i]
if pinfo.isExplicit && !pinfo.isProp then
go arg
else
go arg
| _ => return ()
/--
Checks whether `vars` satisfies the `grind_pattern` constraints attached at `thm`.
@ -672,14 +707,25 @@ In the example above, a `map_map` instance should be added to the logical contex
Remark: `proof` is used to extract the universe parameters in the proof.
-/
private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM Bool := do
private def checkConstraints (thm : EMatchTheorem) (gen : Nat) (proof : Expr) (args : Array Expr) : GoalM Bool := do
if thm.cnstrs.isEmpty then return true
/- **Note**: Only top-level theorems have constraints. -/
let .const declName us := proof | return true
let info ← getConstInfo declName
thm.cnstrs.allM fun cnstr => do
match cnstr with
| .notDefEq lhs rhs => checkNotDefEq info.levelParams us args lhs rhs
| .notDefEq lhs rhs => checkDefEq (expectedResult := false) info.levelParams us args lhs rhs
| .defEq lhs rhs => checkDefEq (expectedResult := true) info.levelParams us args lhs rhs
| .depthLt lhs n => return (← getLHS args lhs).approxDepth.toNat < n
| .isGround lhs => let lhs ← getLHS args lhs; return !lhs.hasFVar && !lhs.hasMVar
| .sizeLt lhs n => checkSize (← getLHS args lhs) n
| .genLt n => return gen < n
| .maxInsts n =>
/-
**Note**: We are checking the number of instances produced in the whole proof.
It may be useful to bound the number of instances in the current branch.
-/
return (← getEMatchTheoremNumInstances thm) + 1 < n
| _ => throwError "NIY"
/--
@ -700,7 +746,7 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
return ()
let (some _, c) ← applyAssignment mvars |>.run c | return ()
let some _ ← synthesizeInsts mvars bis | return ()
if (← checkConstraints thm proof mvars) then
if (← checkConstraints thm c.gen proof mvars) then
let proof := mkAppN proof mvars
if (← mvars.allM (·.mvarId!.isAssigned)) then
addNewInstance thm proof c.gen

View file

@ -380,7 +380,7 @@ inductive EMatchTheoremConstraint where
| /--
Instantiates the theorem only if its generation is less than `n`
-/
genLt (lhs : Nat) (n : Nat)
genLt (n : Nat)
| /--
Constraints of the form `is_ground x`. Instantiates the theorem only if
`x` is ground term.

View file

@ -16,8 +16,8 @@ def isStrictValue := leading_parser nonReservedSymbol "is_strict_value " >> iden
def isGround := leading_parser nonReservedSymbol "is_ground " >> ident >> optional ";"
def sizeLt := leading_parser nonReservedSymbol "size " >> ident >> " < " >> numLit >> optional ";"
def depthLt := leading_parser nonReservedSymbol "depth " >> ident >> " < " >> numLit >> optional ";"
def genLt := leading_parser nonReservedSymbol "gen " >> ident >> " < " >> numLit >> optional ";"
def maxInsts := leading_parser nonReservedSymbol "max_insts " >> numLit >> optional ";"
def genLt := leading_parser nonReservedSymbol "gen" >> " < " >> numLit >> optional ";"
def maxInsts := leading_parser nonReservedSymbol "max_insts" >> " < " >> numLit >> optional ";"
def guard := leading_parser nonReservedSymbol "guard " >> checkColGe "irrelevant" >> termParser >> optional ";"
def check := leading_parser nonReservedSymbol "check " >> checkColGe "irrelevant" >> termParser >> optional ";"
def notDefEq := leading_parser atomic (ident >> " =/= ") >> checkColGe "irrelevant" >> termParser >> optional ";"

View file

@ -358,6 +358,9 @@ private def incCounter [Hashable α] [BEq α] (s : PHashMap α Nat) (k : α) : P
private def saveEMatchTheorem (thm : EMatchTheorem) : GrindM Unit := do
modify fun s => { s with counters.thm := incCounter s.counters.thm thm.origin }
def getEMatchTheoremNumInstances (thm : EMatchTheorem) : GrindM Nat := do
return (← get).counters.thm.find? thm.origin |>.getD 0
def saveCases (declName : Name) : GrindM Unit := do
modify fun s => { s with counters.case := incCounter s.counters.case declName }

View file

@ -0,0 +1,110 @@
namespace Ex1
opaque f : Nat → Nat
axiom fax : f x ≥ f (f x)
grind_pattern fax => f x where
depth x < 2
/--
trace: [grind.ematch.instance] fax: f a ≥ f (f a)
[grind.ematch.instance] fax: f (f a) ≥ f (f (f a))
-/
#guard_msgs (drop error, trace) in
set_option trace.grind.ematch.instance true in
example (h : f a = 0) : False := by
grind
end Ex1
namespace Ex2
opaque f : Nat → Nat
axiom fax : f x ≥ f (f x)
grind_pattern fax => f x where
is_ground x
depth x < 3
opaque b : Nat
-- Theorems containing `a` should not be instantiate since it is a local variable
/--
trace: [grind.ematch.instance] fax: f b ≥ f (f b)
[grind.ematch.instance] fax: f (f b) ≥ f (f (f b))
[grind.ematch.instance] fax: f (f (f b)) ≥ f (f (f (f b)))
-/
#guard_msgs (drop error, trace) in
set_option trace.grind.ematch.instance true in
example : f a = 0 → f b = 0 → False := by
grind
end Ex2
namespace Ex3
def f {α : Type} : ααα := fun x _ => x
axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x)
grind_pattern fax => f x x where
size x < 5
/--
trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a)
[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a))
-/
#guard_msgs (drop error, trace) in
set_option trace.grind.ematch.instance true in
example (a b : List (List Nat)) : f a a = b → False := by
grind
end Ex3
namespace Ex4
def f {α : Type} : ααα := fun x _ => x
axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x)
grind_pattern fax => f x x where
gen < 2
/--
trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a)
[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a))
-/
#guard_msgs (drop error, trace) in
set_option trace.grind.ematch.instance true in
example (a b : List (List Nat)) : f a a = b → False := by
grind
end Ex4
namespace Ex5
opaque f : Nat → Nat
axiom fax : f x ≥ f (f x)
grind_pattern fax => f x where
max_insts < 4
/--
trace: [grind.ematch.instance] fax: f c ≥ f (f c)
[grind.ematch.instance] fax: f b ≥ f (f b)
[grind.ematch.instance] fax: f a ≥ f (f a)
-/
#guard_msgs (drop error, trace) in
set_option trace.grind.ematch.instance true in
example : f a = 0 → f b = 0 → f c = 0 → False := by
grind
end Ex5
namespace Ex6
opaque g : List Nat → Nat
opaque f : List Nat → List Nat
axiom gax (as : List Nat) : g as > g (f as)
grind_pattern gax => g as where
as =?= _ :: _
/-- trace: [grind.ematch.instance] gax: g [1, 2, 3] > g (f [1, 2, 3]) -/
#guard_msgs (drop error, trace) in
set_option trace.grind.ematch.instance true in
example (h : g [1, 2, 3] > 0) : False := by
grind
end Ex6