feat: case splitting match-expressions with overlapping patterns in grind (#6735)
This PR adds support for case splitting on `match`-expressions with overlapping patterns to the `grind` tactic. `grind` can now solve examples such as: ``` inductive S where | mk1 (n : Nat) | mk2 (n : Nat) (s : S) | mk3 (n : Bool) | mk4 (s1 s2 : S) def g (x y : S) := match x, y with | .mk1 a, _ => a + 2 | _, .mk2 1 (.mk4 _ _) => 3 | .mk3 _, .mk4 _ _ => 4 | _, _ => 5 example : g a b > 1 := by grind [g.eq_def] ```
This commit is contained in:
parent
3881f21df1
commit
de31faa470
6 changed files with 234 additions and 45 deletions
|
|
@ -7,9 +7,43 @@ prelude
|
|||
import Lean.Meta.Tactic.Util
|
||||
import Lean.Meta.Tactic.Cases
|
||||
import Lean.Meta.Match.MatcherApp
|
||||
import Lean.Meta.Tactic.Grind.MatchCond
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/--
|
||||
Returns `true` if `e` is of the form `∀ ..., _ = _ ... -> False`
|
||||
-/
|
||||
private def isMatchCond (e : Expr) : Bool := Id.run do
|
||||
let mut e := e
|
||||
let mut hasEqs := false
|
||||
repeat
|
||||
let .forallE _ d b _ := e | return false
|
||||
if d.isEq || d.isHEq then hasEqs := true
|
||||
if b.isFalse then return hasEqs
|
||||
e := b
|
||||
return true
|
||||
|
||||
/--
|
||||
Given a splitter alternative, annotate the terms that are `match`-expression
|
||||
conditions corresponding to overlapping patterns.
|
||||
-/
|
||||
private def addMatchCondsToAlt (alt : Expr) : Expr := Id.run do
|
||||
let .forallE _ d b _ := alt
|
||||
| return alt
|
||||
let d := if isMatchCond d then markAsMatchCond d else d
|
||||
return alt.updateForallE! d (addMatchCondsToAlt b)
|
||||
|
||||
/--
|
||||
Annotates the `match`-expression conditions in the alternatives in the given
|
||||
`match` splitter type.
|
||||
-/
|
||||
private def addMatchCondsToSplitter (splitterType : Expr) (numAlts : Nat) : Expr := Id.run do
|
||||
if numAlts == 0 then return splitterType
|
||||
let .forallE _ alt b _ := splitterType
|
||||
| return splitterType
|
||||
return splitterType.updateForallE! (addMatchCondsToAlt alt) (addMatchCondsToSplitter b (numAlts-1))
|
||||
|
||||
def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
|
||||
let some app ← matchMatcherApp? e
|
||||
| throwTacticEx `grind.casesMatch mvarId m!"`match`-expression expected{indentExpr e}"
|
||||
|
|
@ -23,7 +57,10 @@ def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.with
|
|||
let splitterApp := mkAppN splitterApp app.params
|
||||
let splitterApp := mkApp splitterApp motive
|
||||
let splitterApp := mkAppN splitterApp app.discrs
|
||||
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType splitterApp) app.alts.size (kind := .syntheticOpaque)
|
||||
let numAlts := app.alts.size
|
||||
let splitterType ← inferType splitterApp
|
||||
let splitterType := addMatchCondsToSplitter splitterType app.alts.size
|
||||
let (mvars, _, _) ← forallMetaBoundedTelescope splitterType numAlts (kind := .syntheticOpaque)
|
||||
let splitterApp := mkAppN splitterApp mvars
|
||||
let val := mkAppN splitterApp eqRefls
|
||||
mvarId.assign val
|
||||
|
|
|
|||
|
|
@ -268,6 +268,7 @@ def addNewEq (lhs rhs proof : Expr) (generation : Nat) : GoalM Unit := do
|
|||
|
||||
/-- Adds a new `fact` justified by the given proof and using the given generation. -/
|
||||
def add (fact : Expr) (proof : Expr) (generation := 0) : GoalM Unit := do
|
||||
if fact.isTrue then return ()
|
||||
storeFact fact
|
||||
trace_goal[grind.assert] "{fact}"
|
||||
if (← isInconsistent) then return ()
|
||||
|
|
|
|||
|
|
@ -216,10 +216,10 @@ Helper function for marking parts of `match`-equation theorem as "do-not-simplif
|
|||
-/
|
||||
private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do
|
||||
if let .forallE n d b bi := prop then
|
||||
let d ← if (← isProp d) then
|
||||
let d := if (← isProp d) then
|
||||
markAsMatchCond d
|
||||
else
|
||||
pure d
|
||||
d
|
||||
withLocalDecl n bi d fun x => do
|
||||
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp)
|
||||
else
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import Lean.Meta.Tactic.Grind.Types
|
|||
import Lean.Meta.Tactic.Grind.Util
|
||||
import Lean.Meta.Tactic.Grind.Canon
|
||||
import Lean.Meta.Tactic.Grind.Beta
|
||||
import Lean.Meta.Tactic.Grind.MatchCond
|
||||
import Lean.Meta.Tactic.Grind.Arith.Internalize
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
|
@ -127,21 +128,12 @@ private partial def internalizePattern (pattern : Expr) (generation : Nat) : Goa
|
|||
|
||||
/-- Internalizes the `MatchCond` gadget. -/
|
||||
private partial def internalizeMatchCond (matchCond : Expr) (generation : Nat) : GoalM Unit := do
|
||||
let_expr Grind.MatchCond e ← matchCond | return ()
|
||||
mkENode' matchCond generation
|
||||
let mut e := e
|
||||
repeat
|
||||
let .forallE _ d b _ := e | break
|
||||
let internalizeLhs (lhs : Expr) : GoalM Unit := do
|
||||
unless lhs.hasLooseBVars do
|
||||
internalize lhs generation
|
||||
registerParent matchCond lhs
|
||||
match_expr d with
|
||||
| Eq _ lhs _ => internalizeLhs lhs
|
||||
| HEq _ lhs _ _ => internalizeLhs lhs
|
||||
| _ => pure ()
|
||||
e := b
|
||||
let (lhss, e') ← collectMatchCondLhssAndAbstract matchCond
|
||||
lhss.forM fun lhs => do internalize lhs generation; registerParent matchCond lhs
|
||||
propagateUp matchCond
|
||||
internalize e' generation
|
||||
pushEq matchCond e' (← mkEqRefl matchCond)
|
||||
|
||||
partial def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do
|
||||
-- Recall that we use the proof as part of the key for a set of instances found so far.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,67 @@ import Lean.Meta.Tactic.Simp.Simproc
|
|||
import Lean.Meta.Tactic.Grind.PropagatorAttr
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
/-!
|
||||
Support for `match`-expressions with overlapping patterns.
|
||||
Recall that when a `match`-expression has overlapping patterns, some of its equation theorems are
|
||||
conditional. Let's consider the following example
|
||||
```
|
||||
inductive S where
|
||||
| mk1 (n : Nat)
|
||||
| mk2 (n : Nat) (s : S)
|
||||
| mk3 (n : Bool)
|
||||
| mk4 (s1 s2 : S)
|
||||
|
||||
def f (x y : S) :=
|
||||
match x, y with
|
||||
| .mk1 a, c => a + 2
|
||||
| a, .mk2 1 (.mk4 b c) => 3
|
||||
| .mk3 a, .mk4 b c => 4
|
||||
| a, b => 5
|
||||
```
|
||||
|
||||
The `match`-expression in this example has 4 equations. The second and fourth are conditional.
|
||||
```
|
||||
f.match_1.eq_2
|
||||
(motive : S → S → Sort u) (a b c : S)
|
||||
(h_1 : (a : Nat) → (c : S) → motive (S.mk1 a) c)
|
||||
(h_2 : (a b c : S) → motive a (S.mk2 1 (b.mk4 c)))
|
||||
(h_3 : (a : Bool) → (b c : S) → motive (S.mk3 a) (b.mk4 c))
|
||||
(h_4 : (a b : S) → motive a b) :
|
||||
(∀ (a_1 : Nat), a = S.mk1 a_1 → False) → -- <<< Condition stating it is not the first case
|
||||
f.match_1 motive a (S.mk2 1 (b.mk4 c)) h_1 h_2 h_3 h_4 = h_2 a b c
|
||||
|
||||
f.match_1.eq_4
|
||||
(motive : S → S → Sort u) (a b : S)
|
||||
(h_1 : (a : Nat) → (c : S) → motive (S.mk1 a) c)
|
||||
(h_2 : (a b c : S) → motive a (S.mk2 1 (b.mk4 c)))
|
||||
(h_3 : (a : Bool) → (b c : S) → motive (S.mk3 a) (b.mk4 c))
|
||||
(h_4 : (a b : S) → motive a b) :
|
||||
(∀ (a_1 : Nat), a = S.mk1 a_1 → False) → -- <<< Condition stating it is not the first case
|
||||
(∀ (b_1 c : S), b = S.mk2 1 (b_1.mk4 c) → False) → -- <<< Condition stating it is not the second case
|
||||
(∀ (a_1 : Bool) (b_1 c : S), a = S.mk3 a_1 → b = b_1.mk4 c → False) → -- -- <<< Condition stating it is not the third case
|
||||
f.match_1 motive a b h_1 h_2 h_3 h_4 = h_4 a b
|
||||
```
|
||||
In the two equational theorems above, we have the following conditions.
|
||||
```
|
||||
- `(∀ (a_1 : Nat), a = S.mk1 a_1 → False)`
|
||||
- `(∀ (b_1 c : S), b = S.mk2 1 (b_1.mk4 c) → False)`
|
||||
- `(∀ (a_1 : Bool) (b_1 c : S), a = S.mk3 a_1 → b = b_1.mk4 c → False)`
|
||||
```
|
||||
When instantiating the equations (and `match`-splitter), we wrap the conditions with the gadget `Grind.MatchCond`.
|
||||
This gadget is used for implementing truth-value propagation. See the propagator `propagateMatchCond` below.
|
||||
For example, given a condition `C` of the form `Grind.MatchCond (∀ (a : Nat), t = S.mk1 a → False)`,
|
||||
if `t` is merged with an equivalence class containing `S.mk2 n s`, then `C` is asseted to `true` by `propagateMatchCond`.
|
||||
|
||||
This module also provides auxiliary functions for detecting congruences between `match`-expression conditions.
|
||||
See function `collectMatchCondLhssAndAbstract`.
|
||||
|
||||
Remark: This note highlights that the representation used for encoding `match`-expressions with
|
||||
overlapping patterns is far from ideal for the `grind` module which operates with equivalence classes
|
||||
and does not perform substitutions like `simp`. While modifying how `match`-expressions are encoded in Lean
|
||||
would require major refactoring and affect many modules, this issue is important to acknowledge.
|
||||
A different representation could simplify `grind`, but it could add extra complexity to other modules.
|
||||
-/
|
||||
|
||||
/--
|
||||
Returns `Grind.MatchCond e`.
|
||||
|
|
@ -17,8 +78,8 @@ Recall that `Grind.MatchCond` is an identity function,
|
|||
but the following simproc is used to prevent the term `e` from being simplified,
|
||||
and we have special support for propagating is truth value.
|
||||
-/
|
||||
def markAsMatchCond (e : Expr) : MetaM Expr :=
|
||||
mkAppM ``Grind.MatchCond #[e]
|
||||
def markAsMatchCond (e : Expr) : Expr :=
|
||||
mkApp (mkConst ``Grind.MatchCond) e
|
||||
|
||||
builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do
|
||||
let_expr Grind.MatchCond _ ← e | return .continue
|
||||
|
|
@ -28,15 +89,103 @@ builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do
|
|||
def addMatchCond (s : Simprocs) : CoreM Simprocs := do
|
||||
s.add ``reduceMatchCond (post := false)
|
||||
|
||||
/--
|
||||
Returns `some (lhs, rhs, isHEq)` if `e` is of the form
|
||||
- `Eq _ lhs rhs` (`isHEq := false`), or
|
||||
- `HEq _ lhs _ rhs` (`isHEq := true`)
|
||||
-/
|
||||
private def isEqHEq? (e : Expr) : Option (Expr × Expr × Bool) :=
|
||||
match_expr e with
|
||||
| Eq _ lhs rhs => some (lhs, rhs, false)
|
||||
| HEq _ lhs _ rhs => some (lhs, rhs, true)
|
||||
| _ => none
|
||||
|
||||
/--
|
||||
Given `e` a `match`-expression condition, returns the left-hand side
|
||||
of the ground equations.
|
||||
-/
|
||||
private def collectMatchCondLhss (e : Expr) : Array Expr := Id.run do
|
||||
let mut r := #[]
|
||||
let mut e := e
|
||||
repeat
|
||||
let .forallE _ d b _ := e | return r
|
||||
if let some (lhs, _, _) := isEqHEq? d then
|
||||
unless lhs.hasLooseBVars do
|
||||
r := r.push lhs
|
||||
e := b
|
||||
return r
|
||||
|
||||
/--
|
||||
Replaces the left-hand side of an equality (or heterogeneous equality) `e` with `lhsNew`.
|
||||
-/
|
||||
private def replaceLhs? (e : Expr) (lhsNew : Expr) : Option Expr :=
|
||||
match_expr e with
|
||||
| f@Eq α lhs rhs => if lhs.hasLooseBVars then none else some (mkApp3 f α lhsNew rhs)
|
||||
| f@HEq α lhs β rhs => if lhs.hasLooseBVars then none else some (mkApp4 f α lhsNew β rhs)
|
||||
| _ => none
|
||||
|
||||
/--
|
||||
Given `e` a `match`-expression condition, returns the left-hand side
|
||||
of the ground equations, **and** function application that abstracts the left-hand sides.
|
||||
As an example, assume we have a `match`-expression condition `C₁` of the form
|
||||
```
|
||||
Grind.MatchCond (∀ y₁ y₂ y₃, t = .mk₁ y₁ → s = .mk₂ y₂ y₃ → False)
|
||||
```
|
||||
then the result returned by this function is
|
||||
```
|
||||
(#[t, s], (fun x₁ x₂ => (∀ y₁ y₂ y₃, x₁ = .mk₁ y₁ → x₂ = .mk₂ y₂ y₃ → False)) t s)
|
||||
```
|
||||
Note that the returned expression is definitionally equal to `C₁`.
|
||||
We use this expression to detect whether two different `match`-expression conditions are
|
||||
congruent.
|
||||
For example, suppose we also have the `match`-expression `C₂` of the form
|
||||
```
|
||||
Grind.MatchCond (∀ y₁ y₂ y₃, a = .mk₁ y₁ → b = .mk₂ y₂ y₃ → False)
|
||||
```
|
||||
This function would return
|
||||
```
|
||||
(#[a, b], (fun x₁ x₂ => (∀ y₁ y₂ y₃, x₁ = .mk₁ y₁ → x₂ = .mk₂ y₂ y₃ → False)) a b)
|
||||
```
|
||||
Note that the lambda abstraction is identical to the first one. Let's call it `l`.
|
||||
Thus, we can write the two pairs above as
|
||||
- `(#[t, s], l t s)`
|
||||
- `(#[a, b], l a b)`
|
||||
Moreover, `C₁` is definitionally equal to `l t s`, and `C₂` is definitionally equal to `l a b`.
|
||||
Then, if `grind` infers that `t = a` and `s = b`, it will detect that `l t s` and `l a b` are
|
||||
equal by congruence, and consequently `C₁` is equal to `C₂`.
|
||||
-/
|
||||
def collectMatchCondLhssAndAbstract (matchCond : Expr) : GoalM (Array Expr × Expr) := do
|
||||
let_expr Grind.MatchCond e := matchCond | return (#[], matchCond)
|
||||
let lhss := collectMatchCondLhss e
|
||||
let rec go (i : Nat) (xs : Array Expr) : GoalM Expr := do
|
||||
if h : i < lhss.size then
|
||||
let lhs := lhss[i]
|
||||
withLocalDeclD ((`x).appendIndexAfter i) (← inferType lhs) fun x =>
|
||||
go (i+1) (xs.push x)
|
||||
else
|
||||
let rec replaceLhss (e : Expr) (i : Nat) : Expr := Id.run do
|
||||
let .forallE _ d b _ := e | return e
|
||||
if h : i < xs.size then
|
||||
if let some dNew := replaceLhs? d xs[i] then
|
||||
return e.updateForallE! dNew (replaceLhss b (i+1))
|
||||
else
|
||||
return e.updateForallE! d (replaceLhss b i)
|
||||
else
|
||||
return e
|
||||
let eAbst := replaceLhss e 0
|
||||
let eLam ← mkLambdaFVars xs eAbst
|
||||
let e' := mkAppN eLam lhss
|
||||
shareCommon e'
|
||||
let e' ← go 0 #[]
|
||||
return (lhss, e')
|
||||
|
||||
/--
|
||||
Helper function for `isSatisfied`.
|
||||
See `isSatisfied`.
|
||||
-/
|
||||
private partial def isMathCondFalseHyp (e : Expr) : GoalM Bool := do
|
||||
match_expr e with
|
||||
| Eq _ lhs rhs => isFalse lhs rhs
|
||||
| HEq _ lhs _ rhs => isFalse lhs rhs
|
||||
| _ => return false
|
||||
private partial def isMatchCondFalseHyp (e : Expr) : GoalM Bool := do
|
||||
let some (lhs, rhs, _) := isEqHEq? e | return false
|
||||
isFalse lhs rhs
|
||||
where
|
||||
isFalse (lhs rhs : Expr) : GoalM Bool := do
|
||||
if lhs.hasLooseBVars then return false
|
||||
|
|
@ -91,18 +240,19 @@ private partial def isStatisfied (e : Expr) : GoalM Bool := do
|
|||
let mut e := e
|
||||
repeat
|
||||
let .forallE _ d b _ := e | break
|
||||
if (← isMathCondFalseHyp d) then
|
||||
if (← isMatchCondFalseHyp d) then
|
||||
trace[grind.debug.matchCond] "satifised{indentExpr e}\nthe following equality is false{indentExpr d}"
|
||||
return true
|
||||
e := b
|
||||
return false
|
||||
|
||||
private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do
|
||||
/-- Constructs a proof for a satisfied `match`-expression condition. -/
|
||||
private partial def mkMatchCondProof? (e : Expr) : GoalM (Option Expr) := do
|
||||
let_expr Grind.MatchCond f ← e | return none
|
||||
forallTelescopeReducing f fun xs _ => do
|
||||
for x in xs do
|
||||
let type ← inferType x
|
||||
if (← isMathCondFalseHyp type) then
|
||||
if (← isMatchCondFalseHyp type) then
|
||||
trace[grind.debug.matchCond] ">>> {type}"
|
||||
let some h ← go? x | pure ()
|
||||
return some (← mkLambdaFVars xs h)
|
||||
|
|
@ -110,10 +260,8 @@ private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do
|
|||
where
|
||||
go? (h : Expr) : GoalM (Option Expr) := do
|
||||
trace[grind.debug.matchCond] "go?: {← inferType h}"
|
||||
let (lhs, rhs, isHeq) ← match_expr (← inferType h) with
|
||||
| Eq _ lhs rhs => pure (lhs, rhs, false)
|
||||
| HEq _ lhs _ rhs => pure (lhs, rhs, true)
|
||||
| _ => return none
|
||||
let some (lhs, rhs, isHeq) := isEqHEq? (← inferType h)
|
||||
| return none
|
||||
let target ← (← get).mvarId.getType
|
||||
let root ← getRootENode lhs
|
||||
let h ← if isHeq then
|
||||
|
|
@ -147,7 +295,7 @@ where
|
|||
builtin_grind_propagator propagateMatchCond ↑Grind.MatchCond := fun e => do
|
||||
trace[grind.debug.matchCond] "visiting{indentExpr e}"
|
||||
if !(← isStatisfied e) then return ()
|
||||
let some h ← mkMathCondProof? e
|
||||
let some h ← mkMatchCondProof? e
|
||||
| reportIssue m!"failed to construct proof for{indentExpr e}"; return ()
|
||||
trace[grind.debug.matchCond] "{← inferType h}"
|
||||
pushEqTrue e <| mkEqTrueCore e h
|
||||
|
|
|
|||
|
|
@ -12,36 +12,47 @@ def f (x y : S) :=
|
|||
| _, _ => 5
|
||||
|
||||
example : f a b < 2 → b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → False := by
|
||||
unfold f
|
||||
grind (splits := 0)
|
||||
grind (splits := 0) [f.eq_def]
|
||||
|
||||
example : b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → f a b = 5 := by
|
||||
unfold f
|
||||
grind (splits := 0)
|
||||
grind (splits := 0) [f.eq_def]
|
||||
|
||||
example : b = .mk2 y1 y2 → y1 = 2 → a = .mk3 n → f a b = 4 := by
|
||||
unfold f
|
||||
grind (splits := 0)
|
||||
grind (splits := 0) [f.eq_def]
|
||||
|
||||
example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk3 n → f a b = 3 := by
|
||||
unfold f
|
||||
grind (splits := 0)
|
||||
grind (splits := 0) [f.eq_def]
|
||||
|
||||
example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk2 s3 s4 → f a b = 3 := by
|
||||
unfold f
|
||||
grind (splits := 0)
|
||||
grind (splits := 0) [f.eq_def]
|
||||
|
||||
example : f a b > 1 := by
|
||||
grind (splits := 1) [f.eq_def]
|
||||
|
||||
example : f a b > 1 := by
|
||||
grind [f.eq_def]
|
||||
|
||||
def g (x y : S) :=
|
||||
match x, y with
|
||||
| .mk1 a, _ => a + 2
|
||||
| _, .mk2 1 (.mk4 _ _) => 3
|
||||
| .mk3 _, .mk4 _ _ => 4
|
||||
| _, _ => 5
|
||||
|
||||
example : g a b > 1 := by
|
||||
grind [g.eq_def]
|
||||
|
||||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons : α → Vec α n → Vec α (n+1)
|
||||
|
||||
def g (v w : Vec α n) : Nat :=
|
||||
def h (v w : Vec α n) : Nat :=
|
||||
match v, w with
|
||||
| _, .cons _ (.cons _ _) => 20
|
||||
| .nil, _ => 30
|
||||
| _, _ => 40
|
||||
|
||||
-- TODO: introduce casts while instantiating equation theorems for `g.match_1`
|
||||
-- example (a b : Vec α 2) : g a b = 20 := by
|
||||
-- unfold g
|
||||
-- TODO: introduce casts while instantiating equation theorems for `h.match_1`
|
||||
-- example (a b : Vec α 2) : h a b = 20 := by
|
||||
-- unfold h
|
||||
-- grind
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue