feat: some grind_pattern constraints (#11405)
This PR implements the following `grind_pattern` constraints: ```lean grind_pattern fax => f x where depth x < 2 grind_pattern fax => f x where is_ground x grind_pattern fax => f x where size x < 5 grind_pattern fax => f x where gen < 2 grind_pattern fax => f x where max_insts < 4 grind_pattern gax => g as where as =?= _ :: _ ```
This commit is contained in:
parent
799d594400
commit
16740a1540
6 changed files with 173 additions and 15 deletions
|
|
@ -100,8 +100,7 @@ where
|
|||
else if kind == ``defEq then
|
||||
elabDefEq xs cnstr[0] cnstr[2]
|
||||
else if kind == ``genLt then
|
||||
let (_, lhs) ← findLHS xs cnstr[1]
|
||||
return .genLt lhs cnstr[3].toNat
|
||||
return .genLt cnstr[2].toNat
|
||||
else if kind == ``sizeLt then
|
||||
let (_, lhs) ← findLHS xs cnstr[1]
|
||||
return .sizeLt lhs cnstr[3].toNat
|
||||
|
|
@ -109,7 +108,7 @@ where
|
|||
let (_, lhs) ← findLHS xs cnstr[1]
|
||||
return .depthLt lhs cnstr[3].toNat
|
||||
else if kind == ``maxInsts then
|
||||
return .maxInsts cnstr[1].toNat
|
||||
return .maxInsts cnstr[2].toNat
|
||||
else if kind == ``isValue then
|
||||
let (_, lhs) ← findLHS xs cnstr[1]
|
||||
return .isValue lhs false
|
||||
|
|
|
|||
|
|
@ -637,13 +637,17 @@ private abbrev withFreshNGen (x : M α) : M α := do
|
|||
finally
|
||||
setNGen ngen
|
||||
|
||||
/--
|
||||
Checks constraints of the form `lhs =/= rhs`.
|
||||
-/
|
||||
private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
|
||||
private def getLHS (args : Array Expr) (lhs : Nat) : MetaM Expr := do
|
||||
unless lhs < args.size do
|
||||
throwError "`grind` internal error, invalid variable in `grind_pattern` constraint"
|
||||
let lhs := args[args.size - lhs - 1]!
|
||||
instantiateMVars args[args.size - lhs - 1]!
|
||||
|
||||
/--
|
||||
Checks constraints of the form `lhs =/= rhs` and `lhs =?= rhs`.
|
||||
`expectedResult` is `true` if `lhs` and `rhs` should be definitionally equal.
|
||||
-/
|
||||
private def checkDefEq (expectedResult : Bool) (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
|
||||
let lhs ← getLHS args lhs
|
||||
/- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/
|
||||
let rhsExpr := rhs.expr.instantiateRev args
|
||||
let rhsExpr := rhsExpr.instantiateLevelParams levelParams us
|
||||
|
|
@ -657,7 +661,38 @@ private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Ar
|
|||
let (_, _, rhsExpr) ← lambdaMetaTelescope rhsExpr (some rhs.numMVars)
|
||||
/- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/
|
||||
let defEq ← isDefEqGuarded lhs rhsExpr
|
||||
return !defEq
|
||||
return defEq == expectedResult
|
||||
|
||||
/--
|
||||
Helper function for checking grind pattern constraints of the form `size e < threshold`
|
||||
Implicit arguments and type information in lambdas and let-expressions are ignored.
|
||||
-/
|
||||
partial def checkSize (e : Expr) (threshold : Nat) : MetaM Bool :=
|
||||
return (← go e |>.run |>.run 0).1.isSome
|
||||
where
|
||||
go (e : Expr) : OptionT (StateT Nat MetaM) Unit := do
|
||||
guard ((← get) < threshold)
|
||||
modify (·+1)
|
||||
match e with
|
||||
| .forallE _ d b _ => go d; go b
|
||||
| .lam _ _ b _ => go b
|
||||
| .letE _ _ v b _ => go v; go b
|
||||
| .mdata _ e
|
||||
| .proj _ _ e => go e
|
||||
| .app .. => e.withApp fun f args => do
|
||||
if f.hasLooseBVars then
|
||||
go f; args.forM go
|
||||
else
|
||||
let paramInfo := (← getFunInfo f).paramInfo
|
||||
for h : i in *...args.size do
|
||||
let arg := args[i]
|
||||
if h : i < paramInfo.size then
|
||||
let pinfo := paramInfo[i]
|
||||
if pinfo.isExplicit && !pinfo.isProp then
|
||||
go arg
|
||||
else
|
||||
go arg
|
||||
| _ => return ()
|
||||
|
||||
/--
|
||||
Checks whether `vars` satisfies the `grind_pattern` constraints attached at `thm`.
|
||||
|
|
@ -672,14 +707,25 @@ In the example above, a `map_map` instance should be added to the logical contex
|
|||
|
||||
Remark: `proof` is used to extract the universe parameters in the proof.
|
||||
-/
|
||||
private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM Bool := do
|
||||
private def checkConstraints (thm : EMatchTheorem) (gen : Nat) (proof : Expr) (args : Array Expr) : GoalM Bool := do
|
||||
if thm.cnstrs.isEmpty then return true
|
||||
/- **Note**: Only top-level theorems have constraints. -/
|
||||
let .const declName us := proof | return true
|
||||
let info ← getConstInfo declName
|
||||
thm.cnstrs.allM fun cnstr => do
|
||||
match cnstr with
|
||||
| .notDefEq lhs rhs => checkNotDefEq info.levelParams us args lhs rhs
|
||||
| .notDefEq lhs rhs => checkDefEq (expectedResult := false) info.levelParams us args lhs rhs
|
||||
| .defEq lhs rhs => checkDefEq (expectedResult := true) info.levelParams us args lhs rhs
|
||||
| .depthLt lhs n => return (← getLHS args lhs).approxDepth.toNat < n
|
||||
| .isGround lhs => let lhs ← getLHS args lhs; return !lhs.hasFVar && !lhs.hasMVar
|
||||
| .sizeLt lhs n => checkSize (← getLHS args lhs) n
|
||||
| .genLt n => return gen < n
|
||||
| .maxInsts n =>
|
||||
/-
|
||||
**Note**: We are checking the number of instances produced in the whole proof.
|
||||
It may be useful to bound the number of instances in the current branch.
|
||||
-/
|
||||
return (← getEMatchTheoremNumInstances thm) + 1 < n
|
||||
| _ => throwError "NIY"
|
||||
|
||||
/--
|
||||
|
|
@ -700,7 +746,7 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
|
|||
return ()
|
||||
let (some _, c) ← applyAssignment mvars |>.run c | return ()
|
||||
let some _ ← synthesizeInsts mvars bis | return ()
|
||||
if (← checkConstraints thm proof mvars) then
|
||||
if (← checkConstraints thm c.gen proof mvars) then
|
||||
let proof := mkAppN proof mvars
|
||||
if (← mvars.allM (·.mvarId!.isAssigned)) then
|
||||
addNewInstance thm proof c.gen
|
||||
|
|
|
|||
|
|
@ -380,7 +380,7 @@ inductive EMatchTheoremConstraint where
|
|||
| /--
|
||||
Instantiates the theorem only if its generation is less than `n`
|
||||
-/
|
||||
genLt (lhs : Nat) (n : Nat)
|
||||
genLt (n : Nat)
|
||||
| /--
|
||||
Constraints of the form `is_ground x`. Instantiates the theorem only if
|
||||
`x` is ground term.
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ def isStrictValue := leading_parser nonReservedSymbol "is_strict_value " >> iden
|
|||
def isGround := leading_parser nonReservedSymbol "is_ground " >> ident >> optional ";"
|
||||
def sizeLt := leading_parser nonReservedSymbol "size " >> ident >> " < " >> numLit >> optional ";"
|
||||
def depthLt := leading_parser nonReservedSymbol "depth " >> ident >> " < " >> numLit >> optional ";"
|
||||
def genLt := leading_parser nonReservedSymbol "gen " >> ident >> " < " >> numLit >> optional ";"
|
||||
def maxInsts := leading_parser nonReservedSymbol "max_insts " >> numLit >> optional ";"
|
||||
def genLt := leading_parser nonReservedSymbol "gen" >> " < " >> numLit >> optional ";"
|
||||
def maxInsts := leading_parser nonReservedSymbol "max_insts" >> " < " >> numLit >> optional ";"
|
||||
def guard := leading_parser nonReservedSymbol "guard " >> checkColGe "irrelevant" >> termParser >> optional ";"
|
||||
def check := leading_parser nonReservedSymbol "check " >> checkColGe "irrelevant" >> termParser >> optional ";"
|
||||
def notDefEq := leading_parser atomic (ident >> " =/= ") >> checkColGe "irrelevant" >> termParser >> optional ";"
|
||||
|
|
|
|||
|
|
@ -358,6 +358,9 @@ private def incCounter [Hashable α] [BEq α] (s : PHashMap α Nat) (k : α) : P
|
|||
private def saveEMatchTheorem (thm : EMatchTheorem) : GrindM Unit := do
|
||||
modify fun s => { s with counters.thm := incCounter s.counters.thm thm.origin }
|
||||
|
||||
def getEMatchTheoremNumInstances (thm : EMatchTheorem) : GrindM Nat := do
|
||||
return (← get).counters.thm.find? thm.origin |>.getD 0
|
||||
|
||||
def saveCases (declName : Name) : GrindM Unit := do
|
||||
modify fun s => { s with counters.case := incCounter s.counters.case declName }
|
||||
|
||||
|
|
|
|||
110
tests/lean/run/grind_pattern_cnstr_2.lean
Normal file
110
tests/lean/run/grind_pattern_cnstr_2.lean
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
|
||||
namespace Ex1
|
||||
opaque f : Nat → Nat
|
||||
axiom fax : f x ≥ f (f x)
|
||||
|
||||
grind_pattern fax => f x where
|
||||
depth x < 2
|
||||
|
||||
/--
|
||||
trace: [grind.ematch.instance] fax: f a ≥ f (f a)
|
||||
[grind.ematch.instance] fax: f (f a) ≥ f (f (f a))
|
||||
-/
|
||||
#guard_msgs (drop error, trace) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example (h : f a = 0) : False := by
|
||||
grind
|
||||
end Ex1
|
||||
|
||||
namespace Ex2
|
||||
opaque f : Nat → Nat
|
||||
axiom fax : f x ≥ f (f x)
|
||||
|
||||
grind_pattern fax => f x where
|
||||
is_ground x
|
||||
depth x < 3
|
||||
|
||||
opaque b : Nat
|
||||
|
||||
-- Theorems containing `a` should not be instantiate since it is a local variable
|
||||
/--
|
||||
trace: [grind.ematch.instance] fax: f b ≥ f (f b)
|
||||
[grind.ematch.instance] fax: f (f b) ≥ f (f (f b))
|
||||
[grind.ematch.instance] fax: f (f (f b)) ≥ f (f (f (f b)))
|
||||
-/
|
||||
#guard_msgs (drop error, trace) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example : f a = 0 → f b = 0 → False := by
|
||||
grind
|
||||
end Ex2
|
||||
|
||||
namespace Ex3
|
||||
def f {α : Type} : α → α → α := fun x _ => x
|
||||
axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x)
|
||||
|
||||
grind_pattern fax => f x x where
|
||||
size x < 5
|
||||
|
||||
/--
|
||||
trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a)
|
||||
[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a))
|
||||
-/
|
||||
#guard_msgs (drop error, trace) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example (a b : List (List Nat)) : f a a = b → False := by
|
||||
grind
|
||||
end Ex3
|
||||
|
||||
namespace Ex4
|
||||
def f {α : Type} : α → α → α := fun x _ => x
|
||||
axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x)
|
||||
|
||||
grind_pattern fax => f x x where
|
||||
gen < 2
|
||||
|
||||
/--
|
||||
trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a)
|
||||
[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a))
|
||||
-/
|
||||
#guard_msgs (drop error, trace) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example (a b : List (List Nat)) : f a a = b → False := by
|
||||
grind
|
||||
end Ex4
|
||||
|
||||
|
||||
namespace Ex5
|
||||
opaque f : Nat → Nat
|
||||
axiom fax : f x ≥ f (f x)
|
||||
|
||||
grind_pattern fax => f x where
|
||||
max_insts < 4
|
||||
|
||||
/--
|
||||
trace: [grind.ematch.instance] fax: f c ≥ f (f c)
|
||||
[grind.ematch.instance] fax: f b ≥ f (f b)
|
||||
[grind.ematch.instance] fax: f a ≥ f (f a)
|
||||
-/
|
||||
#guard_msgs (drop error, trace) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example : f a = 0 → f b = 0 → f c = 0 → False := by
|
||||
grind
|
||||
|
||||
end Ex5
|
||||
|
||||
namespace Ex6
|
||||
|
||||
opaque g : List Nat → Nat
|
||||
opaque f : List Nat → List Nat
|
||||
axiom gax (as : List Nat) : g as > g (f as)
|
||||
|
||||
grind_pattern gax => g as where
|
||||
as =?= _ :: _
|
||||
|
||||
/-- trace: [grind.ematch.instance] gax: g [1, 2, 3] > g (f [1, 2, 3]) -/
|
||||
#guard_msgs (drop error, trace) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example (h : g [1, 2, 3] > 0) : False := by
|
||||
grind
|
||||
|
||||
end Ex6
|
||||
Loading…
Add table
Reference in a new issue