feat: [grind?] attribute (#8426)
This PR adds the attribute `[grind?]`. It is like `[grind]` but displays inferred E-matching patterns. It is a more convinient than writing. Thanks @kim-em for suggesting this feature. ```lean set_option trace.grind.ematch.pattern true ``` This PR also improves some tests, and adds helper function `ENode.isRoot`.
This commit is contained in:
parent
a541b8e75e
commit
c28b052576
10 changed files with 102 additions and 48 deletions
|
|
@ -30,6 +30,7 @@ syntax grindIntro := &"intro "
|
|||
syntax grindExt := &"ext "
|
||||
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt
|
||||
syntax (name := grind) "grind" (grindMod)? : attr
|
||||
syntax (name := grind?) "grind?" (grindMod)? : attr
|
||||
end Attr
|
||||
end Lean.Parser
|
||||
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do
|
|||
-- Assign on expressions associated with cutsat terms or interpreted terms
|
||||
for e in goal.exprs do
|
||||
let node ← goal.getENode e
|
||||
if isSameExpr node.root node.self then
|
||||
if node.isRoot then
|
||||
if (← isIntNatENode node) then
|
||||
if let some v ← getAssignment? goal node.self then
|
||||
if v.den == 1 then used := used.insert v.num
|
||||
|
|
@ -111,7 +111,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do
|
|||
-- Assign the remaining ones with values not used by cutsat
|
||||
for e in goal.exprs do
|
||||
let node ← goal.getENode e
|
||||
if isSameExpr node.root node.self then
|
||||
if node.isRoot then
|
||||
if (← isIntNatENode node) then
|
||||
if model[node.self]?.isNone then
|
||||
let v := pickUnusedValue goal model node.self nextVal used
|
||||
|
|
|
|||
|
|
@ -49,11 +49,20 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
|
|||
def throwInvalidUsrModifier : CoreM α :=
|
||||
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
|
||||
|
||||
builtin_initialize
|
||||
/--
|
||||
Auxiliary function for registering `grind` and `grind?` attributes.
|
||||
The `grind?` is an alias for `grind` which displays patterns using `logInfo`.
|
||||
It is just a convenience for users.
|
||||
-/
|
||||
private def registerGrindAttr (showInfo : Bool) : IO Unit :=
|
||||
registerBuiltinAttribute {
|
||||
name := `grind
|
||||
name := if showInfo then `grind? else `grind
|
||||
descr :=
|
||||
"The `[grind]` attribute is used to annotate declarations.\
|
||||
let header := if showInfo then
|
||||
"The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information."
|
||||
else
|
||||
"The `[grind]` attribute is used to annotate declarations."
|
||||
header ++ "\
|
||||
\
|
||||
When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\
|
||||
will mark the theorem for use in heuristic instantiations by the `grind` tactic,
|
||||
|
|
@ -73,12 +82,12 @@ builtin_initialize
|
|||
add := fun declName stx attrKind => MetaM.run' do
|
||||
match (← getAttrKindFromOpt stx) with
|
||||
| .ematch .user => throwInvalidUsrModifier
|
||||
| .ematch k => addEMatchAttr declName attrKind k
|
||||
| .ematch k => addEMatchAttr declName attrKind k (showInfo := showInfo)
|
||||
| .cases eager => addCasesAttr declName eager attrKind
|
||||
| .intro =>
|
||||
if let some info ← isCasesAttrPredicateCandidate? declName false then
|
||||
for ctor in info.ctors do
|
||||
addEMatchAttr ctor attrKind .default
|
||||
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
|
||||
else
|
||||
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
|
||||
| .ext => addExtAttr declName attrKind
|
||||
|
|
@ -89,10 +98,12 @@ builtin_initialize
|
|||
-- If it is an inductive predicate,
|
||||
-- we also add the constructors (intro rules) as E-matching rules
|
||||
for ctor in info.ctors do
|
||||
addEMatchAttr ctor attrKind .default
|
||||
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
|
||||
else
|
||||
addEMatchAttr declName attrKind .default
|
||||
addEMatchAttr declName attrKind .default (showInfo := showInfo)
|
||||
erase := fun declName => MetaM.run' do
|
||||
if showInfo then
|
||||
throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead"
|
||||
if (← isCasesAttrCandidate declName false) then
|
||||
eraseCasesAttr declName
|
||||
else if (← isExtTheorem declName) then
|
||||
|
|
@ -101,4 +112,8 @@ builtin_initialize
|
|||
eraseEMatchAttr declName
|
||||
}
|
||||
|
||||
builtin_initialize
|
||||
registerGrindAttr true
|
||||
registerGrindAttr false
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -589,18 +589,24 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
|
|||
msg := msg ++ m!"{x} : {← inferType x}"
|
||||
addMessageContextFull msg
|
||||
|
||||
private def logPatternWhen (showInfo : Bool) (origin : Origin) (patterns : List Expr) : MetaM Unit := do
|
||||
if showInfo then
|
||||
logInfo m!"{← origin.pp}: {patterns.map ppPattern}"
|
||||
|
||||
/--
|
||||
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
|
||||
Pattern variables are represented using de Bruijn indices.
|
||||
-/
|
||||
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : EMatchTheoremKind) : MetaM EMatchTheorem := do
|
||||
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr)
|
||||
(patterns : List Expr) (kind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
|
||||
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
|
||||
if symbols.isEmpty then
|
||||
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
|
||||
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
|
||||
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
|
||||
if let .missing pos ← checkCoverage proof numParams bvarFound then
|
||||
let pats : MessageData := m!"{patterns.map ppPattern}"
|
||||
throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
|
||||
logPatternWhen showInfo origin patterns
|
||||
return {
|
||||
proof, patterns, numParams, symbols
|
||||
levelParams, origin, kind
|
||||
|
|
@ -627,7 +633,7 @@ Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs
|
|||
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
|
||||
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern.
|
||||
-/
|
||||
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) : MetaM EMatchTheorem := do
|
||||
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) (showInfo := false) : MetaM EMatchTheorem := do
|
||||
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
|
||||
let (lhs, rhs) ← match_expr type with
|
||||
| Eq _ lhs rhs => pure (lhs, rhs)
|
||||
|
|
@ -640,15 +646,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
|
|||
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat normConfig}"
|
||||
let pats := splitWhileForbidden (pat.abstract xs)
|
||||
return (xs.size, pats)
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs)
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs) (showInfo := showInfo)
|
||||
|
||||
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
|
||||
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (showInfo := false) : MetaM EMatchTheorem := do
|
||||
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
|
||||
let_expr f@Eq α lhs rhs := type
|
||||
| throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}"
|
||||
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
|
||||
return (xs.size, [pat.abstract xs])
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd (showInfo := showInfo)
|
||||
|
||||
/--
|
||||
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
|
||||
|
|
@ -657,8 +663,8 @@ creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
|
|||
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the
|
||||
pattern.
|
||||
-/
|
||||
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) : MetaM EMatchTheorem := do
|
||||
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs
|
||||
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) (showInfo := false) : MetaM EMatchTheorem := do
|
||||
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs (showInfo := showInfo)
|
||||
|
||||
/--
|
||||
Adds an E-matching theorem to the environment.
|
||||
|
|
@ -844,13 +850,13 @@ since the theorem is already in the `grind` state and there is nothing to be ins
|
|||
-/
|
||||
def mkEMatchTheoremWithKind?
|
||||
(origin : Origin) (levelParams : Array Name) (proof : Expr) (kind : EMatchTheoremKind)
|
||||
(groundPatterns := true) : MetaM (Option EMatchTheorem) := do
|
||||
(groundPatterns := true) (showInfo := false) : MetaM (Option EMatchTheorem) := do
|
||||
if kind == .eqLhs then
|
||||
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true))
|
||||
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true) (showInfo := showInfo))
|
||||
else if kind == .eqRhs then
|
||||
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false))
|
||||
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false) (showInfo := showInfo))
|
||||
else if kind == .eqBwd then
|
||||
return (← mkEMatchEqBwdTheoremCore origin levelParams proof)
|
||||
return (← mkEMatchEqBwdTheoremCore origin levelParams proof (showInfo := showInfo))
|
||||
let type ← inferType proof
|
||||
/-
|
||||
Remark: we should not use `forallTelescopeReducing` (with default reducibility) here
|
||||
|
|
@ -894,25 +900,26 @@ where
|
|||
return none
|
||||
let numParams := xs.size
|
||||
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
|
||||
logPatternWhen showInfo origin patterns
|
||||
return some {
|
||||
proof, patterns, numParams, symbols
|
||||
levelParams, origin, kind
|
||||
}
|
||||
|
||||
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) : MetaM EMatchTheorem := do
|
||||
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
|
||||
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
|
||||
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind (showInfo := showInfo)
|
||||
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
|
||||
return thm
|
||||
|
||||
def mkEMatchEqTheoremsForDef? (declName : Name) : MetaM (Option (Array EMatchTheorem)) := do
|
||||
def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Option (Array EMatchTheorem)) := do
|
||||
let some eqns ← getEqnsFor? declName | return none
|
||||
eqns.mapM fun eqn => do
|
||||
mkEMatchEqTheorem eqn (normalizePattern := true)
|
||||
mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo)
|
||||
|
||||
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) : MetaM Unit := do
|
||||
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
|
||||
if wasOriginallyTheorem (← getEnv) declName then
|
||||
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs)) attrKind
|
||||
else if let some thms ← mkEMatchEqTheoremsForDef? declName then
|
||||
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (showInfo := showInfo)) attrKind
|
||||
else if let some thms ← mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then
|
||||
unless useLhs do
|
||||
throwError "`{declName}` is a definition, you must only use the left-hand side for extracting patterns"
|
||||
thms.forM (ematchTheoremsExt.add · attrKind)
|
||||
|
|
@ -935,20 +942,20 @@ def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMat
|
|||
throwErr
|
||||
return s.erase <| .decl declName
|
||||
|
||||
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) : MetaM Unit := do
|
||||
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM Unit := do
|
||||
if thmKind == .eqLhs then
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := true)
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
|
||||
else if thmKind == .eqRhs then
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := false)
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
|
||||
else if thmKind == .eqBoth then
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := true)
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := false)
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
|
||||
else
|
||||
let info ← getConstInfo declName
|
||||
if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then
|
||||
addGrindEqAttr declName attrKind thmKind
|
||||
addGrindEqAttr declName attrKind thmKind (showInfo := showInfo)
|
||||
else
|
||||
let thm ← mkEMatchTheoremForDecl declName thmKind
|
||||
let thm ← mkEMatchTheoremForDecl declName thmKind (showInfo := showInfo)
|
||||
ematchTheoremsExt.add thm attrKind
|
||||
|
||||
def eraseEMatchAttr (declName : Name) : MetaM Unit := do
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ def checkInvariants (expensive := false) : GoalM Unit := do
|
|||
for e in (← getExprs) do
|
||||
let node ← getENode e
|
||||
checkParents node.self
|
||||
if isSameExpr node.self node.root then
|
||||
if node.isRoot then
|
||||
checkEqc node
|
||||
if expensive then
|
||||
checkPtrEqImpliesStructEq
|
||||
|
|
|
|||
|
|
@ -343,6 +343,9 @@ structure ENode where
|
|||
-- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3.
|
||||
deriving Inhabited, Repr
|
||||
|
||||
def ENode.isRoot (n : ENode) :=
|
||||
isSameExpr n.self n.root
|
||||
|
||||
def ENode.isCongrRoot (n : ENode) :=
|
||||
isSameExpr n.self n.congr
|
||||
|
||||
|
|
@ -1250,7 +1253,7 @@ def filterENodes (p : ENode → GoalM Bool) : GoalM (Array ENode) := do
|
|||
def forEachEqcRoot (f : ENode → GoalM Unit) : GoalM Unit := do
|
||||
for e in (← getExprs) do
|
||||
let n ← getENode e
|
||||
if isSameExpr n.self n.root then
|
||||
if n.isRoot then
|
||||
f n
|
||||
|
||||
abbrev Propagator := Expr → GoalM Unit
|
||||
|
|
@ -1302,7 +1305,7 @@ partial def Goal.getEqcs (goal : Goal) : List (List Expr) := Id.run do
|
|||
let mut r : List (List Expr) := []
|
||||
for e in goal.exprs do
|
||||
let some node := goal.getENode? e | pure ()
|
||||
if isSameExpr node.root node.self then
|
||||
if node.isRoot then
|
||||
r := goal.getEqc node.self :: r
|
||||
return r
|
||||
|
||||
|
|
|
|||
|
|
@ -44,3 +44,18 @@ set_option trace.grind.ematch.pattern true in
|
|||
set_option trace.grind.ematch.pattern true in
|
||||
@[grind =>] theorem State.update_le_update (h : State.le σ' σ) : State.le (σ'.update x v) (σ.update x v) :=
|
||||
sorry
|
||||
|
||||
|
||||
namespace Foo
|
||||
|
||||
/-- info: Rtrans: [R #4 #3, R #3 #2] -/
|
||||
#guard_msgs (info) in
|
||||
@[grind? ->]
|
||||
axiom Rtrans {x y z : Nat} : R x y → R y z → R x z
|
||||
|
||||
/-- info: Rtrans': [R #4 #3, R #3 #2] -/
|
||||
#guard_msgs (info) in
|
||||
@[grind? →]
|
||||
axiom Rtrans' {x y z : Nat} : R x y → R y z → R x z
|
||||
|
||||
end Foo
|
||||
|
|
|
|||
|
|
@ -5,15 +5,19 @@ attribute [grind] List.countP_nil List.countP_cons
|
|||
|
||||
theorem List.countP_le_countP (hpq : ∀ x ∈ l, P x → Q x) :
|
||||
l.countP P ≤ l.countP Q := by
|
||||
induction l with
|
||||
| nil => grind
|
||||
| cons x xs ih =>
|
||||
grind
|
||||
induction l <;> grind
|
||||
|
||||
-- TODO: how to explain to the user that `l.countP P ≤ l.countP Q` is a bad pattern
|
||||
grind_pattern List.countP_le_countP => l.countP P, l.countP Q
|
||||
|
||||
theorem List.countP_lt_countP (hpq : ∀ x ∈ l, P x → Q x) (y:α) (hx: y ∈ l) (hxP : P y = false) (hxQ : Q y) :
|
||||
l.countP P < l.countP Q := by
|
||||
induction l with
|
||||
| nil => grind
|
||||
| cons x xs ih =>
|
||||
have : xs.countP P ≤ xs.countP Q := countP_le_countP (by grind)
|
||||
grind
|
||||
induction l <;> grind
|
||||
|
||||
/--
|
||||
info: List.countP_nil: [@List.countP #1 #0 (@List.nil _)]
|
||||
---
|
||||
info: List.countP_cons: [@List.countP #3 #2 (@List.cons _ #1 #0)]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
attribute [grind?] List.countP_nil List.countP_cons
|
||||
|
|
|
|||
|
|
@ -76,3 +76,12 @@ trace: [grind.assert] x1 = appV a_2 b
|
|||
#guard_msgs (trace) in
|
||||
example : x1 = appV a b → x2 = appV x1 c → x3 = appV b c → x4 = appV a x3 → HEq x2 x4 := by
|
||||
grind
|
||||
|
||||
|
||||
/--
|
||||
info: appV_assoc': [@appV #6 #5 (@HAdd.hAdd `[Nat] `[Nat] `[Nat] `[instHAdd] #4 #3) #2 (@appV _ #4 #3 #1 #0)]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
@[grind? =]
|
||||
theorem appV_assoc' (a : Vector α n) (b : Vector α m) (c : Vector α n') :
|
||||
HEq (appV a (appV b c)) (appV (appV a b) c) := sorry
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ theorem length_pos_of_ne_nil {l : List α} (h : l ≠ []) : 0 < l.length := by
|
|||
|
||||
theorem getLast?_dropLast {xs : List α} :
|
||||
xs.dropLast.getLast? = if xs.length ≤ 1 then none else xs[xs.length - 2]? := by
|
||||
grind (splits := 9) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?,
|
||||
grind (splits := 15) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?,
|
||||
List.head?_eq_getLast?_reverse, getElem?_dropLast, List.getLast?_reverse, List.length_dropLast,
|
||||
List.length_reverse, length_nil, List.reverse_reverse, head?_nil, List.getElem?_eq_none,
|
||||
length_pos_of_ne_nil, getLast?_nil, List.head?_reverse, List.getLast?_eq_head?_reverse,
|
||||
Loading…
Add table
Reference in a new issue