feat: add [grind ←=] attribute (#6702)
This PR adds support for equality backward reasoning to `grind`. We can
illustrate the new feature with the following example. Suppose we have a
theorem:
```lean
theorem inv_eq {a b : α} (w : a * b = 1) : inv a = b
```
and we want to instantiate the theorem whenever we are tying to prove
`inv t = s` for some terms `t` and `s`
The attribute `[grind ←]` is not applicable in this case because, by
default, `=` is not eligible for E-matching. The new attribute `[grind
←=]` instructs `grind` to use the equality and consider disequalities in
the `grind` proof state as candidates for E-matching.
This commit is contained in:
parent
a062eea204
commit
9b7bd58c14
5 changed files with 122 additions and 4 deletions
|
|
@ -11,10 +11,11 @@ namespace Lean.Parser.Attr
|
|||
syntax grindEq := "="
|
||||
syntax grindEqBoth := atomic("_" "=" "_")
|
||||
syntax grindEqRhs := atomic("=" "_")
|
||||
syntax grindEqBwd := atomic("←" "=")
|
||||
syntax grindBwd := "←"
|
||||
syntax grindFwd := "→"
|
||||
|
||||
syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd
|
||||
syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd
|
||||
|
||||
syntax (name := grind) "grind" (grindThmMod)? : attr
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ def doNotSimp {α : Sort u} (a : α) : α := a
|
|||
/-- Gadget for representing offsets `t+k` in patterns. -/
|
||||
def offset (a b : Nat) : Nat := a + b
|
||||
|
||||
/-- Gadget for representing `a = b` in patterns for backward propagation. -/
|
||||
def eqBwdPattern (a b : α) : Prop := a = b
|
||||
|
||||
/--
|
||||
Gadget for annotating the equalities in `match`-equations conclusions.
|
||||
`_origin` is the term used to instantiate the `match`-equation using E-matching.
|
||||
|
|
|
|||
|
|
@ -310,12 +310,44 @@ private def main (p : Expr) (cnstrs : List Cnstr) : M Unit := do
|
|||
modify fun s => { s with choiceStack := [c] }
|
||||
processChoices
|
||||
|
||||
/--
|
||||
Entry point for matching `lhs ←= rhs` patterns.
|
||||
It traverses disequalities `a = b`, and tries to solve two matching problems:
|
||||
1- match `lhs` with `a` and `rhs` with `b`
|
||||
2- match `lhs` with `b` and `rhs` with `a`
|
||||
-/
|
||||
private def matchEqBwdPat (p : Expr) : M Unit := do
|
||||
let_expr Grind.eqBwdPattern pα plhs prhs := p | return ()
|
||||
let numParams := (← read).thm.numParams
|
||||
let assignment := mkArray numParams unassigned
|
||||
let useMT := (← read).useMT
|
||||
let gmt := (← getThe Goal).gmt
|
||||
let false ← getFalseExpr
|
||||
let mut curr := false
|
||||
repeat
|
||||
if (← checkMaxInstancesExceeded) then return ()
|
||||
let n ← getENode curr
|
||||
if (n.heqProofs || n.isCongrRoot) &&
|
||||
(!useMT || n.mt == gmt) then
|
||||
let_expr Eq α lhs rhs := n.self | pure ()
|
||||
if (← isDefEq α pα) then
|
||||
let c₀ : Choice := { cnstrs := [], assignment, gen := n.generation }
|
||||
let go (lhs rhs : Expr) : M Unit := do
|
||||
let some c₁ ← matchArg? c₀ plhs lhs |>.run | return ()
|
||||
let some c₂ ← matchArg? c₁ prhs rhs |>.run | return ()
|
||||
modify fun s => { s with choiceStack := [c₂] }
|
||||
processChoices
|
||||
go lhs rhs
|
||||
go rhs lhs
|
||||
if isSameExpr n.next false then return ()
|
||||
curr := n.next
|
||||
|
||||
def ematchTheorem (thm : EMatchTheorem) : M Unit := do
|
||||
if (← checkMaxInstancesExceeded) then return ()
|
||||
withReader (fun ctx => { ctx with thm }) do
|
||||
let ps := thm.patterns
|
||||
match ps, (← read).useMT with
|
||||
| [p], _ => main p []
|
||||
| [p], _ => if isEqBwdPattern p then matchEqBwdPat p else main p []
|
||||
| p::ps, false => main p (ps.map (.continue ·))
|
||||
| _::_, true => tryAll ps []
|
||||
| _, _ => unreachable!
|
||||
|
|
|
|||
|
|
@ -39,6 +39,17 @@ def isOffsetPattern? (pat : Expr) : Option (Expr × Nat) := Id.run do
|
|||
let .lit (.natVal k) := k | none
|
||||
return some (pat, k)
|
||||
|
||||
def mkEqBwdPattern (u : List Level) (α : Expr) (lhs rhs : Expr) : Expr :=
|
||||
mkApp3 (mkConst ``Grind.eqBwdPattern u) α lhs rhs
|
||||
|
||||
def isEqBwdPattern (e : Expr) : Bool :=
|
||||
e.isAppOfArity ``Grind.eqBwdPattern 3
|
||||
|
||||
def isEqBwdPattern? (e : Expr) : Option (Expr × Expr) :=
|
||||
let_expr Grind.eqBwdPattern _ lhs rhs := e
|
||||
| none
|
||||
some (lhs, rhs)
|
||||
|
||||
def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do
|
||||
let pat ← instantiateMVars pat
|
||||
let pat ← unfoldReducible pat
|
||||
|
|
@ -314,7 +325,8 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
|
|||
let some f := getPatternFn? pattern
|
||||
| throwError "invalid pattern, (non-forbidden) application expected{indentExpr pattern}"
|
||||
assert! f.isConst || f.isFVar
|
||||
saveSymbol f.toHeadIndex
|
||||
unless f.isConstOf ``Grind.eqBwdPattern do
|
||||
saveSymbol f.toHeadIndex
|
||||
let mut args := pattern.getAppArgs.toVector
|
||||
let supportMask ← getPatternSupportMask f args.size
|
||||
for h : i in [:args.size] do
|
||||
|
|
@ -481,6 +493,8 @@ Pattern variables are represented using de Bruijn indices.
|
|||
-/
|
||||
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do
|
||||
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
|
||||
if symbols.isEmpty then
|
||||
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
|
||||
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
|
||||
if let .missing pos ← checkCoverage proof numParams bvarFound then
|
||||
let pats : MessageData := m!"{patterns.map ppPattern}"
|
||||
|
|
@ -523,6 +537,14 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
|
|||
return (xs.size, pats)
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns
|
||||
|
||||
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
|
||||
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
|
||||
let_expr f@Eq α lhs rhs := type
|
||||
| throwError "invalid E-matching `≠` theorem, conclusion must be an equality{indentExpr type}"
|
||||
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
|
||||
return (xs.size, [pat.abstract xs])
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns
|
||||
|
||||
/--
|
||||
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
|
||||
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
|
||||
|
|
@ -552,13 +574,14 @@ def getEMatchTheorems : CoreM EMatchTheorems :=
|
|||
return ematchTheoremsExt.getState (← getEnv)
|
||||
|
||||
inductive TheoremKind where
|
||||
| eqLhs | eqRhs | eqBoth | fwd | bwd | default
|
||||
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
private def TheoremKind.toAttribute : TheoremKind → String
|
||||
| .eqLhs => "[grind =]"
|
||||
| .eqRhs => "[grind =_]"
|
||||
| .eqBoth => "[grind _=_]"
|
||||
| .eqBwd => "[grind ←=]"
|
||||
| .fwd => "[grind →]"
|
||||
| .bwd => "[grind ←]"
|
||||
| .default => "[grind]"
|
||||
|
|
@ -567,6 +590,7 @@ private def TheoremKind.explainFailure : TheoremKind → String
|
|||
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
|
||||
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
|
||||
| .eqBoth => unreachable! -- eqBoth is a macro
|
||||
| .eqBwd => "failed to use theorem's conclusion as a pattern"
|
||||
| .fwd => "failed to find patterns in the antecedents of the theorem"
|
||||
| .bwd => "failed to find patterns in the theorem's conclusion"
|
||||
| .default => "failed to find patterns"
|
||||
|
|
@ -656,6 +680,8 @@ def mkEMatchTheoremWithKind? (origin : Origin) (levelParams : Array Name) (proof
|
|||
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true))
|
||||
else if kind == .eqRhs then
|
||||
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false))
|
||||
else if kind == .eqBwd then
|
||||
return (← mkEMatchEqBwdTheoremCore origin levelParams proof)
|
||||
let type ← inferType proof
|
||||
forallTelescopeReducing type fun xs type => do
|
||||
let searchPlaces ← match kind with
|
||||
|
|
@ -687,6 +713,7 @@ def getTheoremKindCore (stx : Syntax) : CoreM TheoremKind := do
|
|||
| `(Parser.Attr.grindThmMod| ←) => return .bwd
|
||||
| `(Parser.Attr.grindThmMod| =_) => return .eqRhs
|
||||
| `(Parser.Attr.grindThmMod| _=_) => return .eqBoth
|
||||
| `(Parser.Attr.grindThmMod| ←=) => return .eqBwd
|
||||
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
|
||||
|
||||
/-- Return theorem kind for `stx` of the form `(Attr.grindThmMod)?` -/
|
||||
|
|
|
|||
55
tests/lean/run/grind_eq_bwd.lean
Normal file
55
tests/lean/run/grind_eq_bwd.lean
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
theorem dummy (x : Nat) : x = x :=
|
||||
rfl
|
||||
|
||||
/--
|
||||
error: invalid pattern for `dummy`
|
||||
[@Lean.Grind.eqBwdPattern `[Nat] #0 #0]
|
||||
the pattern does not contain constant symbols for indexing
|
||||
-/
|
||||
#guard_msgs in
|
||||
attribute [grind ←=] dummy
|
||||
|
||||
def α : Type := sorry
|
||||
def inv : α → α := sorry
|
||||
def mul : α → α → α := sorry
|
||||
def one : α := sorry
|
||||
|
||||
theorem inv_eq {a b : α} (w : mul a b = one) : inv a = b := sorry
|
||||
|
||||
/--
|
||||
info: [grind.ematch.pattern] inv_eq: [@Lean.Grind.eqBwdPattern `[α] (inv #2) #1]
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.grind.ematch.pattern true in
|
||||
attribute [grind ←=] inv_eq
|
||||
|
||||
example {a b : α} (w : mul a b = one) : inv a = b := by
|
||||
grind
|
||||
|
||||
structure S where
|
||||
f : Bool → α
|
||||
h : mul (f true) (f false) = one
|
||||
h' : mul (f false) (f true) = one
|
||||
|
||||
attribute [grind =] S.h S.h'
|
||||
|
||||
example (s : S) : inv (s.f true) = s.f false := by
|
||||
grind
|
||||
|
||||
example (s : S) : s.f false = inv (s.f true) := by
|
||||
grind
|
||||
|
||||
example (s : S) : a = false → s.f a = inv (s.f true) := by
|
||||
grind
|
||||
|
||||
example (s : S) : a ≠ s.f false → a = inv (s.f true) → False := by
|
||||
grind
|
||||
|
||||
/--
|
||||
info: [grind.ematch.instance] inv_eq: mul (s.f true) (s.f false) = one → inv (s.f true) = s.f false
|
||||
[grind.ematch.instance] S.h: mul (s.f true) (s.f false) = one
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example (s : S) : inv (s.f true) = s.f false := by
|
||||
grind
|
||||
Loading…
Add table
Reference in a new issue