feat: case splitting match-expressions with overlapping patterns in grind (#6735)

This PR adds support for case splitting on `match`-expressions with
overlapping patterns to the `grind` tactic. `grind` can now solve
examples such as:
```
inductive S where
  | mk1 (n : Nat)
  | mk2 (n : Nat) (s : S)
  | mk3 (n : Bool)
  | mk4 (s1 s2 : S)

def g (x y : S) :=
  match x, y with
  | .mk1 a, _ => a + 2
  | _, .mk2 1 (.mk4 _ _) => 3
  | .mk3 _, .mk4 _ _ => 4
  | _, _ => 5

example : g a b > 1 := by
  grind [g.eq_def]
```
This commit is contained in:
Leonardo de Moura 2025-01-21 18:59:42 -08:00 committed by GitHub
parent 3881f21df1
commit de31faa470
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 234 additions and 45 deletions

View file

@ -7,9 +7,43 @@ prelude
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Cases
import Lean.Meta.Match.MatcherApp
import Lean.Meta.Tactic.Grind.MatchCond
namespace Lean.Meta.Grind
/--
Returns `true` if `e` is of the form `∀ ..., _ = _ ... -> False`
-/
private def isMatchCond (e : Expr) : Bool := Id.run do
let mut e := e
let mut hasEqs := false
repeat
let .forallE _ d b _ := e | return false
if d.isEq || d.isHEq then hasEqs := true
if b.isFalse then return hasEqs
e := b
return true
/--
Given a splitter alternative, annotate the terms that are `match`-expression
conditions corresponding to overlapping patterns.
-/
private def addMatchCondsToAlt (alt : Expr) : Expr := Id.run do
let .forallE _ d b _ := alt
| return alt
let d := if isMatchCond d then markAsMatchCond d else d
return alt.updateForallE! d (addMatchCondsToAlt b)
/--
Annotates the `match`-expression conditions in the alternatives in the given
`match` splitter type.
-/
private def addMatchCondsToSplitter (splitterType : Expr) (numAlts : Nat) : Expr := Id.run do
if numAlts == 0 then return splitterType
let .forallE _ alt b _ := splitterType
| return splitterType
return splitterType.updateForallE! (addMatchCondsToAlt alt) (addMatchCondsToSplitter b (numAlts-1))
def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
let some app ← matchMatcherApp? e
| throwTacticEx `grind.casesMatch mvarId m!"`match`-expression expected{indentExpr e}"
@ -23,7 +57,10 @@ def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.with
let splitterApp := mkAppN splitterApp app.params
let splitterApp := mkApp splitterApp motive
let splitterApp := mkAppN splitterApp app.discrs
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType splitterApp) app.alts.size (kind := .syntheticOpaque)
let numAlts := app.alts.size
let splitterType ← inferType splitterApp
let splitterType := addMatchCondsToSplitter splitterType app.alts.size
let (mvars, _, _) ← forallMetaBoundedTelescope splitterType numAlts (kind := .syntheticOpaque)
let splitterApp := mkAppN splitterApp mvars
let val := mkAppN splitterApp eqRefls
mvarId.assign val

View file

@ -268,6 +268,7 @@ def addNewEq (lhs rhs proof : Expr) (generation : Nat) : GoalM Unit := do
/-- Adds a new `fact` justified by the given proof and using the given generation. -/
def add (fact : Expr) (proof : Expr) (generation := 0) : GoalM Unit := do
if fact.isTrue then return ()
storeFact fact
trace_goal[grind.assert] "{fact}"
if (← isInconsistent) then return ()

View file

@ -216,10 +216,10 @@ Helper function for marking parts of `match`-equation theorem as "do-not-simplif
-/
private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do
if let .forallE n d b bi := prop then
let d if (← isProp d) then
let d := if (← isProp d) then
markAsMatchCond d
else
pure d
d
withLocalDecl n bi d fun x => do
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp)
else

View file

@ -13,6 +13,7 @@ import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Grind.Beta
import Lean.Meta.Tactic.Grind.MatchCond
import Lean.Meta.Tactic.Grind.Arith.Internalize
namespace Lean.Meta.Grind
@ -127,21 +128,12 @@ private partial def internalizePattern (pattern : Expr) (generation : Nat) : Goa
/-- Internalizes the `MatchCond` gadget. -/
private partial def internalizeMatchCond (matchCond : Expr) (generation : Nat) : GoalM Unit := do
let_expr Grind.MatchCond e ← matchCond | return ()
mkENode' matchCond generation
let mut e := e
repeat
let .forallE _ d b _ := e | break
let internalizeLhs (lhs : Expr) : GoalM Unit := do
unless lhs.hasLooseBVars do
internalize lhs generation
registerParent matchCond lhs
match_expr d with
| Eq _ lhs _ => internalizeLhs lhs
| HEq _ lhs _ _ => internalizeLhs lhs
| _ => pure ()
e := b
let (lhss, e') ← collectMatchCondLhssAndAbstract matchCond
lhss.forM fun lhs => do internalize lhs generation; registerParent matchCond lhs
propagateUp matchCond
internalize e' generation
pushEq matchCond e' (← mkEqRefl matchCond)
partial def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do
-- Recall that we use the proof as part of the key for a set of instances found so far.

View file

@ -10,6 +10,67 @@ import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Grind.PropagatorAttr
namespace Lean.Meta.Grind
/-!
Support for `match`-expressions with overlapping patterns.
Recall that when a `match`-expression has overlapping patterns, some of its equation theorems are
conditional. Let's consider the following example
```
inductive S where
| mk1 (n : Nat)
| mk2 (n : Nat) (s : S)
| mk3 (n : Bool)
| mk4 (s1 s2 : S)
def f (x y : S) :=
match x, y with
| .mk1 a, c => a + 2
| a, .mk2 1 (.mk4 b c) => 3
| .mk3 a, .mk4 b c => 4
| a, b => 5
```
The `match`-expression in this example has 4 equations. The second and fourth are conditional.
```
f.match_1.eq_2
(motive : S → S → Sort u) (a b c : S)
(h_1 : (a : Nat) → (c : S) → motive (S.mk1 a) c)
(h_2 : (a b c : S) → motive a (S.mk2 1 (b.mk4 c)))
(h_3 : (a : Bool) → (b c : S) → motive (S.mk3 a) (b.mk4 c))
(h_4 : (a b : S) → motive a b) :
(∀ (a_1 : Nat), a = S.mk1 a_1 → False) → -- <<< Condition stating it is not the first case
f.match_1 motive a (S.mk2 1 (b.mk4 c)) h_1 h_2 h_3 h_4 = h_2 a b c
f.match_1.eq_4
(motive : S → S → Sort u) (a b : S)
(h_1 : (a : Nat) → (c : S) → motive (S.mk1 a) c)
(h_2 : (a b c : S) → motive a (S.mk2 1 (b.mk4 c)))
(h_3 : (a : Bool) → (b c : S) → motive (S.mk3 a) (b.mk4 c))
(h_4 : (a b : S) → motive a b) :
(∀ (a_1 : Nat), a = S.mk1 a_1 → False) → -- <<< Condition stating it is not the first case
(∀ (b_1 c : S), b = S.mk2 1 (b_1.mk4 c) → False) → -- <<< Condition stating it is not the second case
(∀ (a_1 : Bool) (b_1 c : S), a = S.mk3 a_1 → b = b_1.mk4 c → False) → -- -- <<< Condition stating it is not the third case
f.match_1 motive a b h_1 h_2 h_3 h_4 = h_4 a b
```
In the two equational theorems above, we have the following conditions.
```
- `(∀ (a_1 : Nat), a = S.mk1 a_1 → False)`
- `(∀ (b_1 c : S), b = S.mk2 1 (b_1.mk4 c) → False)`
- `(∀ (a_1 : Bool) (b_1 c : S), a = S.mk3 a_1 → b = b_1.mk4 c → False)`
```
When instantiating the equations (and `match`-splitter), we wrap the conditions with the gadget `Grind.MatchCond`.
This gadget is used for implementing truth-value propagation. See the propagator `propagateMatchCond` below.
For example, given a condition `C` of the form `Grind.MatchCond (∀ (a : Nat), t = S.mk1 a → False)`,
if `t` is merged with an equivalence class containing `S.mk2 n s`, then `C` is asseted to `true` by `propagateMatchCond`.
This module also provides auxiliary functions for detecting congruences between `match`-expression conditions.
See function `collectMatchCondLhssAndAbstract`.
Remark: This note highlights that the representation used for encoding `match`-expressions with
overlapping patterns is far from ideal for the `grind` module which operates with equivalence classes
and does not perform substitutions like `simp`. While modifying how `match`-expressions are encoded in Lean
would require major refactoring and affect many modules, this issue is important to acknowledge.
A different representation could simplify `grind`, but it could add extra complexity to other modules.
-/
/--
Returns `Grind.MatchCond e`.
@ -17,8 +78,8 @@ Recall that `Grind.MatchCond` is an identity function,
but the following simproc is used to prevent the term `e` from being simplified,
and we have special support for propagating is truth value.
-/
def markAsMatchCond (e : Expr) : MetaM Expr :=
mkAppM ``Grind.MatchCond #[e]
def markAsMatchCond (e : Expr) : Expr :=
mkApp (mkConst ``Grind.MatchCond) e
builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do
let_expr Grind.MatchCond _ ← e | return .continue
@ -28,15 +89,103 @@ builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do
def addMatchCond (s : Simprocs) : CoreM Simprocs := do
s.add ``reduceMatchCond (post := false)
/--
Returns `some (lhs, rhs, isHEq)` if `e` is of the form
- `Eq _ lhs rhs` (`isHEq := false`), or
- `HEq _ lhs _ rhs` (`isHEq := true`)
-/
private def isEqHEq? (e : Expr) : Option (Expr × Expr × Bool) :=
match_expr e with
| Eq _ lhs rhs => some (lhs, rhs, false)
| HEq _ lhs _ rhs => some (lhs, rhs, true)
| _ => none
/--
Given `e` a `match`-expression condition, returns the left-hand side
of the ground equations.
-/
private def collectMatchCondLhss (e : Expr) : Array Expr := Id.run do
let mut r := #[]
let mut e := e
repeat
let .forallE _ d b _ := e | return r
if let some (lhs, _, _) := isEqHEq? d then
unless lhs.hasLooseBVars do
r := r.push lhs
e := b
return r
/--
Replaces the left-hand side of an equality (or heterogeneous equality) `e` with `lhsNew`.
-/
private def replaceLhs? (e : Expr) (lhsNew : Expr) : Option Expr :=
match_expr e with
| f@Eq α lhs rhs => if lhs.hasLooseBVars then none else some (mkApp3 f α lhsNew rhs)
| f@HEq α lhs β rhs => if lhs.hasLooseBVars then none else some (mkApp4 f α lhsNew β rhs)
| _ => none
/--
Given `e` a `match`-expression condition, returns the left-hand side
of the ground equations, **and** function application that abstracts the left-hand sides.
As an example, assume we have a `match`-expression condition `C₁` of the form
```
Grind.MatchCond (∀ y₁ y₂ y₃, t = .mk₁ y₁ → s = .mk₂ y₂ y₃ → False)
```
then the result returned by this function is
```
(#[t, s], (fun x₁ x₂ => (∀ y₁ y₂ y₃, x₁ = .mk₁ y₁ → x₂ = .mk₂ y₂ y₃ → False)) t s)
```
Note that the returned expression is definitionally equal to `C₁`.
We use this expression to detect whether two different `match`-expression conditions are
congruent.
For example, suppose we also have the `match`-expression `C₂` of the form
```
Grind.MatchCond (∀ y₁ y₂ y₃, a = .mk₁ y₁ → b = .mk₂ y₂ y₃ → False)
```
This function would return
```
(#[a, b], (fun x₁ x₂ => (∀ y₁ y₂ y₃, x₁ = .mk₁ y₁ → x₂ = .mk₂ y₂ y₃ → False)) a b)
```
Note that the lambda abstraction is identical to the first one. Let's call it `l`.
Thus, we can write the two pairs above as
- `(#[t, s], l t s)`
- `(#[a, b], l a b)`
Moreover, `C₁` is definitionally equal to `l t s`, and `C₂` is definitionally equal to `l a b`.
Then, if `grind` infers that `t = a` and `s = b`, it will detect that `l t s` and `l a b` are
equal by congruence, and consequently `C₁` is equal to `C₂`.
-/
def collectMatchCondLhssAndAbstract (matchCond : Expr) : GoalM (Array Expr × Expr) := do
let_expr Grind.MatchCond e := matchCond | return (#[], matchCond)
let lhss := collectMatchCondLhss e
let rec go (i : Nat) (xs : Array Expr) : GoalM Expr := do
if h : i < lhss.size then
let lhs := lhss[i]
withLocalDeclD ((`x).appendIndexAfter i) (← inferType lhs) fun x =>
go (i+1) (xs.push x)
else
let rec replaceLhss (e : Expr) (i : Nat) : Expr := Id.run do
let .forallE _ d b _ := e | return e
if h : i < xs.size then
if let some dNew := replaceLhs? d xs[i] then
return e.updateForallE! dNew (replaceLhss b (i+1))
else
return e.updateForallE! d (replaceLhss b i)
else
return e
let eAbst := replaceLhss e 0
let eLam ← mkLambdaFVars xs eAbst
let e' := mkAppN eLam lhss
shareCommon e'
let e' ← go 0 #[]
return (lhss, e')
/--
Helper function for `isSatisfied`.
See `isSatisfied`.
-/
private partial def isMathCondFalseHyp (e : Expr) : GoalM Bool := do
match_expr e with
| Eq _ lhs rhs => isFalse lhs rhs
| HEq _ lhs _ rhs => isFalse lhs rhs
| _ => return false
private partial def isMatchCondFalseHyp (e : Expr) : GoalM Bool := do
let some (lhs, rhs, _) := isEqHEq? e | return false
isFalse lhs rhs
where
isFalse (lhs rhs : Expr) : GoalM Bool := do
if lhs.hasLooseBVars then return false
@ -91,18 +240,19 @@ private partial def isStatisfied (e : Expr) : GoalM Bool := do
let mut e := e
repeat
let .forallE _ d b _ := e | break
if (← isMathCondFalseHyp d) then
if (← isMatchCondFalseHyp d) then
trace[grind.debug.matchCond] "satifised{indentExpr e}\nthe following equality is false{indentExpr d}"
return true
e := b
return false
private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do
/-- Constructs a proof for a satisfied `match`-expression condition. -/
private partial def mkMatchCondProof? (e : Expr) : GoalM (Option Expr) := do
let_expr Grind.MatchCond f ← e | return none
forallTelescopeReducing f fun xs _ => do
for x in xs do
let type ← inferType x
if (← isMathCondFalseHyp type) then
if (← isMatchCondFalseHyp type) then
trace[grind.debug.matchCond] ">>> {type}"
let some h ← go? x | pure ()
return some (← mkLambdaFVars xs h)
@ -110,10 +260,8 @@ private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do
where
go? (h : Expr) : GoalM (Option Expr) := do
trace[grind.debug.matchCond] "go?: {← inferType h}"
let (lhs, rhs, isHeq) ← match_expr (← inferType h) with
| Eq _ lhs rhs => pure (lhs, rhs, false)
| HEq _ lhs _ rhs => pure (lhs, rhs, true)
| _ => return none
let some (lhs, rhs, isHeq) := isEqHEq? (← inferType h)
| return none
let target ← (← get).mvarId.getType
let root ← getRootENode lhs
let h ← if isHeq then
@ -147,7 +295,7 @@ where
builtin_grind_propagator propagateMatchCond ↑Grind.MatchCond := fun e => do
trace[grind.debug.matchCond] "visiting{indentExpr e}"
if !(← isStatisfied e) then return ()
let some h ← mkMathCondProof? e
let some h ← mkMatchCondProof? e
| reportIssue m!"failed to construct proof for{indentExpr e}"; return ()
trace[grind.debug.matchCond] "{← inferType h}"
pushEqTrue e <| mkEqTrueCore e h

View file

@ -12,36 +12,47 @@ def f (x y : S) :=
| _, _ => 5
example : f a b < 2 → b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → False := by
unfold f
grind (splits := 0)
grind (splits := 0) [f.eq_def]
example : b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → f a b = 5 := by
unfold f
grind (splits := 0)
grind (splits := 0) [f.eq_def]
example : b = .mk2 y1 y2 → y1 = 2 → a = .mk3 n → f a b = 4 := by
unfold f
grind (splits := 0)
grind (splits := 0) [f.eq_def]
example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk3 n → f a b = 3 := by
unfold f
grind (splits := 0)
grind (splits := 0) [f.eq_def]
example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk2 s3 s4 → f a b = 3 := by
unfold f
grind (splits := 0)
grind (splits := 0) [f.eq_def]
example : f a b > 1 := by
grind (splits := 1) [f.eq_def]
example : f a b > 1 := by
grind [f.eq_def]
def g (x y : S) :=
match x, y with
| .mk1 a, _ => a + 2
| _, .mk2 1 (.mk4 _ _) => 3
| .mk3 _, .mk4 _ _ => 4
| _, _ => 5
example : g a b > 1 := by
grind [g.eq_def]
inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → Vec α n → Vec α (n+1)
def g (v w : Vec α n) : Nat :=
def h (v w : Vec α n) : Nat :=
match v, w with
| _, .cons _ (.cons _ _) => 20
| .nil, _ => 30
| _, _ => 40
-- TODO: introduce casts while instantiating equation theorems for `g.match_1`
-- example (a b : Vec α 2) : g a b = 20 := by
-- unfold g
-- TODO: introduce casts while instantiating equation theorems for `h.match_1`
-- example (a b : Vec α 2) : h a b = 20 := by
-- unfold h
-- grind