feat: add [grind ←=] attribute (#6702)

This PR adds support for equality backward reasoning to `grind`. We can
illustrate the new feature with the following example. Suppose we have a
theorem:
```lean
theorem inv_eq {a b : α} (w : a * b = 1) : inv a = b
```
and we want to instantiate the theorem whenever we are tying to prove
`inv t = s` for some terms `t` and `s`
The attribute `[grind ←]` is not applicable in this case because, by
default, `=` is not eligible for E-matching. The new attribute `[grind
←=]` instructs `grind` to use the equality and consider disequalities in
the `grind` proof state as candidates for E-matching.
This commit is contained in:
Leonardo de Moura 2025-01-19 17:16:01 -08:00 committed by GitHub
parent a062eea204
commit 9b7bd58c14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 122 additions and 4 deletions

View file

@ -11,10 +11,11 @@ namespace Lean.Parser.Attr
syntax grindEq := "="
syntax grindEqBoth := atomic("_" "=" "_")
syntax grindEqRhs := atomic("=" "_")
syntax grindEqBwd := atomic("←" "=")
syntax grindBwd := "←"
syntax grindFwd := "→"
syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd
syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd
syntax (name := grind) "grind" (grindThmMod)? : attr

View file

@ -21,6 +21,9 @@ def doNotSimp {α : Sort u} (a : α) : α := a
/-- Gadget for representing offsets `t+k` in patterns. -/
def offset (a b : Nat) : Nat := a + b
/-- Gadget for representing `a = b` in patterns for backward propagation. -/
def eqBwdPattern (a b : α) : Prop := a = b
/--
Gadget for annotating the equalities in `match`-equations conclusions.
`_origin` is the term used to instantiate the `match`-equation using E-matching.

View file

@ -310,12 +310,44 @@ private def main (p : Expr) (cnstrs : List Cnstr) : M Unit := do
modify fun s => { s with choiceStack := [c] }
processChoices
/--
Entry point for matching `lhs ←= rhs` patterns.
It traverses disequalities `a = b`, and tries to solve two matching problems:
1- match `lhs` with `a` and `rhs` with `b`
2- match `lhs` with `b` and `rhs` with `a`
-/
private def matchEqBwdPat (p : Expr) : M Unit := do
let_expr Grind.eqBwdPattern pα plhs prhs := p | return ()
let numParams := (← read).thm.numParams
let assignment := mkArray numParams unassigned
let useMT := (← read).useMT
let gmt := (← getThe Goal).gmt
let false ← getFalseExpr
let mut curr := false
repeat
if (← checkMaxInstancesExceeded) then return ()
let n ← getENode curr
if (n.heqProofs || n.isCongrRoot) &&
(!useMT || n.mt == gmt) then
let_expr Eq α lhs rhs := n.self | pure ()
if (← isDefEq α pα) then
let c₀ : Choice := { cnstrs := [], assignment, gen := n.generation }
let go (lhs rhs : Expr) : M Unit := do
let some c₁ ← matchArg? c₀ plhs lhs |>.run | return ()
let some c₂ ← matchArg? c₁ prhs rhs |>.run | return ()
modify fun s => { s with choiceStack := [c₂] }
processChoices
go lhs rhs
go rhs lhs
if isSameExpr n.next false then return ()
curr := n.next
def ematchTheorem (thm : EMatchTheorem) : M Unit := do
if (← checkMaxInstancesExceeded) then return ()
withReader (fun ctx => { ctx with thm }) do
let ps := thm.patterns
match ps, (← read).useMT with
| [p], _ => main p []
| [p], _ => if isEqBwdPattern p then matchEqBwdPat p else main p []
| p::ps, false => main p (ps.map (.continue ·))
| _::_, true => tryAll ps []
| _, _ => unreachable!

View file

@ -39,6 +39,17 @@ def isOffsetPattern? (pat : Expr) : Option (Expr × Nat) := Id.run do
let .lit (.natVal k) := k | none
return some (pat, k)
def mkEqBwdPattern (u : List Level) (α : Expr) (lhs rhs : Expr) : Expr :=
mkApp3 (mkConst ``Grind.eqBwdPattern u) α lhs rhs
def isEqBwdPattern (e : Expr) : Bool :=
e.isAppOfArity ``Grind.eqBwdPattern 3
def isEqBwdPattern? (e : Expr) : Option (Expr × Expr) :=
let_expr Grind.eqBwdPattern _ lhs rhs := e
| none
some (lhs, rhs)
def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do
let pat ← instantiateMVars pat
let pat ← unfoldReducible pat
@ -314,7 +325,8 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
let some f := getPatternFn? pattern
| throwError "invalid pattern, (non-forbidden) application expected{indentExpr pattern}"
assert! f.isConst || f.isFVar
saveSymbol f.toHeadIndex
unless f.isConstOf ``Grind.eqBwdPattern do
saveSymbol f.toHeadIndex
let mut args := pattern.getAppArgs.toVector
let supportMask ← getPatternSupportMask f args.size
for h : i in [:args.size] do
@ -481,6 +493,8 @@ Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
if symbols.isEmpty then
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
@ -523,6 +537,14 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
return (xs.size, pats)
mkEMatchTheoremCore origin levelParams numParams proof patterns
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
let_expr f@Eq α lhs rhs := type
| throwError "invalid E-matching `≠` theorem, conclusion must be an equality{indentExpr type}"
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
return (xs.size, [pat.abstract xs])
mkEMatchTheoremCore origin levelParams numParams proof patterns
/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
@ -552,13 +574,14 @@ def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState (← getEnv)
inductive TheoremKind where
| eqLhs | eqRhs | eqBoth | fwd | bwd | default
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default
deriving Inhabited, BEq, Repr
private def TheoremKind.toAttribute : TheoremKind → String
| .eqLhs => "[grind =]"
| .eqRhs => "[grind =_]"
| .eqBoth => "[grind _=_]"
| .eqBwd => "[grind ←=]"
| .fwd => "[grind →]"
| .bwd => "[grind ←]"
| .default => "[grind]"
@ -567,6 +590,7 @@ private def TheoremKind.explainFailure : TheoremKind → String
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
| .eqBoth => unreachable! -- eqBoth is a macro
| .eqBwd => "failed to use theorem's conclusion as a pattern"
| .fwd => "failed to find patterns in the antecedents of the theorem"
| .bwd => "failed to find patterns in the theorem's conclusion"
| .default => "failed to find patterns"
@ -656,6 +680,8 @@ def mkEMatchTheoremWithKind? (origin : Origin) (levelParams : Array Name) (proof
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true))
else if kind == .eqRhs then
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false))
else if kind == .eqBwd then
return (← mkEMatchEqBwdTheoremCore origin levelParams proof)
let type ← inferType proof
forallTelescopeReducing type fun xs type => do
let searchPlaces ← match kind with
@ -687,6 +713,7 @@ def getTheoremKindCore (stx : Syntax) : CoreM TheoremKind := do
| `(Parser.Attr.grindThmMod| ←) => return .bwd
| `(Parser.Attr.grindThmMod| =_) => return .eqRhs
| `(Parser.Attr.grindThmMod| _=_) => return .eqBoth
| `(Parser.Attr.grindThmMod| ←=) => return .eqBwd
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
/-- Return theorem kind for `stx` of the form `(Attr.grindThmMod)?` -/

View file

@ -0,0 +1,55 @@
theorem dummy (x : Nat) : x = x :=
rfl
/--
error: invalid pattern for `dummy`
[@Lean.Grind.eqBwdPattern `[Nat] #0 #0]
the pattern does not contain constant symbols for indexing
-/
#guard_msgs in
attribute [grind ←=] dummy
def α : Type := sorry
def inv : αα := sorry
def mul : ααα := sorry
def one : α := sorry
theorem inv_eq {a b : α} (w : mul a b = one) : inv a = b := sorry
/--
info: [grind.ematch.pattern] inv_eq: [@Lean.Grind.eqBwdPattern `[α] (inv #2) #1]
-/
#guard_msgs in
set_option trace.grind.ematch.pattern true in
attribute [grind ←=] inv_eq
example {a b : α} (w : mul a b = one) : inv a = b := by
grind
structure S where
f : Bool → α
h : mul (f true) (f false) = one
h' : mul (f false) (f true) = one
attribute [grind =] S.h S.h'
example (s : S) : inv (s.f true) = s.f false := by
grind
example (s : S) : s.f false = inv (s.f true) := by
grind
example (s : S) : a = false → s.f a = inv (s.f true) := by
grind
example (s : S) : a ≠ s.f false → a = inv (s.f true) → False := by
grind
/--
info: [grind.ematch.instance] inv_eq: mul (s.f true) (s.f false) = one → inv (s.f true) = s.f false
[grind.ematch.instance] S.h: mul (s.f true) (s.f false) = one
-/
#guard_msgs (info) in
set_option trace.grind.ematch.instance true in
example (s : S) : inv (s.f true) = s.f false := by
grind