From de31faa470f2da42c0346256fd4a8b8c86b4b264 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 21 Jan 2025 18:59:42 -0800 Subject: [PATCH] feat: case splitting `match`-expressions with overlapping patterns in `grind` (#6735) This PR adds support for case splitting on `match`-expressions with overlapping patterns to the `grind` tactic. `grind` can now solve examples such as: ``` inductive S where | mk1 (n : Nat) | mk2 (n : Nat) (s : S) | mk3 (n : Bool) | mk4 (s1 s2 : S) def g (x y : S) := match x, y with | .mk1 a, _ => a + 2 | _, .mk2 1 (.mk4 _ _) => 3 | .mk3 _, .mk4 _ _ => 4 | _, _ => 5 example : g a b > 1 := by grind [g.eq_def] ``` --- src/Lean/Meta/Tactic/Grind/CasesMatch.lean | 39 +++- src/Lean/Meta/Tactic/Grind/Core.lean | 1 + src/Lean/Meta/Tactic/Grind/EMatch.lean | 4 +- src/Lean/Meta/Tactic/Grind/Internalize.lean | 18 +- src/Lean/Meta/Tactic/Grind/MatchCond.lean | 178 ++++++++++++++++-- .../lean/run/grind_match_eq_propagation.lean | 39 ++-- 6 files changed, 234 insertions(+), 45 deletions(-) diff --git a/src/Lean/Meta/Tactic/Grind/CasesMatch.lean b/src/Lean/Meta/Tactic/Grind/CasesMatch.lean index 1e5f07ed6a..0244bb1eef 100644 --- a/src/Lean/Meta/Tactic/Grind/CasesMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/CasesMatch.lean @@ -7,9 +7,43 @@ prelude import Lean.Meta.Tactic.Util import Lean.Meta.Tactic.Cases import Lean.Meta.Match.MatcherApp +import Lean.Meta.Tactic.Grind.MatchCond namespace Lean.Meta.Grind +/-- +Returns `true` if `e` is of the form `∀ ..., _ = _ ... -> False` +-/ +private def isMatchCond (e : Expr) : Bool := Id.run do + let mut e := e + let mut hasEqs := false + repeat + let .forallE _ d b _ := e | return false + if d.isEq || d.isHEq then hasEqs := true + if b.isFalse then return hasEqs + e := b + return true + +/-- +Given a splitter alternative, annotate the terms that are `match`-expression +conditions corresponding to overlapping patterns. +-/ +private def addMatchCondsToAlt (alt : Expr) : Expr := Id.run do + let .forallE _ d b _ := alt + | return alt + let d := if isMatchCond d then markAsMatchCond d else d + return alt.updateForallE! d (addMatchCondsToAlt b) + +/-- +Annotates the `match`-expression conditions in the alternatives in the given +`match` splitter type. +-/ +private def addMatchCondsToSplitter (splitterType : Expr) (numAlts : Nat) : Expr := Id.run do + if numAlts == 0 then return splitterType + let .forallE _ alt b _ := splitterType + | return splitterType + return splitterType.updateForallE! (addMatchCondsToAlt alt) (addMatchCondsToSplitter b (numAlts-1)) + def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do let some app ← matchMatcherApp? e | throwTacticEx `grind.casesMatch mvarId m!"`match`-expression expected{indentExpr e}" @@ -23,7 +57,10 @@ def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.with let splitterApp := mkAppN splitterApp app.params let splitterApp := mkApp splitterApp motive let splitterApp := mkAppN splitterApp app.discrs - let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType splitterApp) app.alts.size (kind := .syntheticOpaque) + let numAlts := app.alts.size + let splitterType ← inferType splitterApp + let splitterType := addMatchCondsToSplitter splitterType app.alts.size + let (mvars, _, _) ← forallMetaBoundedTelescope splitterType numAlts (kind := .syntheticOpaque) let splitterApp := mkAppN splitterApp mvars let val := mkAppN splitterApp eqRefls mvarId.assign val diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index e5158d2de9..1ab00bbdef 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -268,6 +268,7 @@ def addNewEq (lhs rhs proof : Expr) (generation : Nat) : GoalM Unit := do /-- Adds a new `fact` justified by the given proof and using the given generation. -/ def add (fact : Expr) (proof : Expr) (generation := 0) : GoalM Unit := do + if fact.isTrue then return () storeFact fact trace_goal[grind.assert] "{fact}" if (← isInconsistent) then return () diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index bef989aaa6..257bfe453e 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -216,10 +216,10 @@ Helper function for marking parts of `match`-equation theorem as "do-not-simplif -/ private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do if let .forallE n d b bi := prop then - let d ← if (← isProp d) then + let d := if (← isProp d) then markAsMatchCond d else - pure d + d withLocalDecl n bi d fun x => do mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp) else diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index e4ce20405d..1e40f7fc6d 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -13,6 +13,7 @@ import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Util import Lean.Meta.Tactic.Grind.Canon import Lean.Meta.Tactic.Grind.Beta +import Lean.Meta.Tactic.Grind.MatchCond import Lean.Meta.Tactic.Grind.Arith.Internalize namespace Lean.Meta.Grind @@ -127,21 +128,12 @@ private partial def internalizePattern (pattern : Expr) (generation : Nat) : Goa /-- Internalizes the `MatchCond` gadget. -/ private partial def internalizeMatchCond (matchCond : Expr) (generation : Nat) : GoalM Unit := do - let_expr Grind.MatchCond e ← matchCond | return () mkENode' matchCond generation - let mut e := e - repeat - let .forallE _ d b _ := e | break - let internalizeLhs (lhs : Expr) : GoalM Unit := do - unless lhs.hasLooseBVars do - internalize lhs generation - registerParent matchCond lhs - match_expr d with - | Eq _ lhs _ => internalizeLhs lhs - | HEq _ lhs _ _ => internalizeLhs lhs - | _ => pure () - e := b + let (lhss, e') ← collectMatchCondLhssAndAbstract matchCond + lhss.forM fun lhs => do internalize lhs generation; registerParent matchCond lhs propagateUp matchCond + internalize e' generation + pushEq matchCond e' (← mkEqRefl matchCond) partial def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do -- Recall that we use the proof as part of the key for a set of instances found so far. diff --git a/src/Lean/Meta/Tactic/Grind/MatchCond.lean b/src/Lean/Meta/Tactic/Grind/MatchCond.lean index da3b990c7c..e37281b82d 100644 --- a/src/Lean/Meta/Tactic/Grind/MatchCond.lean +++ b/src/Lean/Meta/Tactic/Grind/MatchCond.lean @@ -10,6 +10,67 @@ import Lean.Meta.Tactic.Simp.Simproc import Lean.Meta.Tactic.Grind.PropagatorAttr namespace Lean.Meta.Grind +/-! +Support for `match`-expressions with overlapping patterns. +Recall that when a `match`-expression has overlapping patterns, some of its equation theorems are +conditional. Let's consider the following example +``` +inductive S where + | mk1 (n : Nat) + | mk2 (n : Nat) (s : S) + | mk3 (n : Bool) + | mk4 (s1 s2 : S) + +def f (x y : S) := + match x, y with + | .mk1 a, c => a + 2 + | a, .mk2 1 (.mk4 b c) => 3 + | .mk3 a, .mk4 b c => 4 + | a, b => 5 +``` + +The `match`-expression in this example has 4 equations. The second and fourth are conditional. +``` +f.match_1.eq_2 + (motive : S → S → Sort u) (a b c : S) + (h_1 : (a : Nat) → (c : S) → motive (S.mk1 a) c) + (h_2 : (a b c : S) → motive a (S.mk2 1 (b.mk4 c))) + (h_3 : (a : Bool) → (b c : S) → motive (S.mk3 a) (b.mk4 c)) + (h_4 : (a b : S) → motive a b) : + (∀ (a_1 : Nat), a = S.mk1 a_1 → False) → -- <<< Condition stating it is not the first case + f.match_1 motive a (S.mk2 1 (b.mk4 c)) h_1 h_2 h_3 h_4 = h_2 a b c + +f.match_1.eq_4 + (motive : S → S → Sort u) (a b : S) + (h_1 : (a : Nat) → (c : S) → motive (S.mk1 a) c) + (h_2 : (a b c : S) → motive a (S.mk2 1 (b.mk4 c))) + (h_3 : (a : Bool) → (b c : S) → motive (S.mk3 a) (b.mk4 c)) + (h_4 : (a b : S) → motive a b) : + (∀ (a_1 : Nat), a = S.mk1 a_1 → False) → -- <<< Condition stating it is not the first case + (∀ (b_1 c : S), b = S.mk2 1 (b_1.mk4 c) → False) → -- <<< Condition stating it is not the second case + (∀ (a_1 : Bool) (b_1 c : S), a = S.mk3 a_1 → b = b_1.mk4 c → False) → -- -- <<< Condition stating it is not the third case + f.match_1 motive a b h_1 h_2 h_3 h_4 = h_4 a b +``` +In the two equational theorems above, we have the following conditions. +``` +- `(∀ (a_1 : Nat), a = S.mk1 a_1 → False)` +- `(∀ (b_1 c : S), b = S.mk2 1 (b_1.mk4 c) → False)` +- `(∀ (a_1 : Bool) (b_1 c : S), a = S.mk3 a_1 → b = b_1.mk4 c → False)` +``` +When instantiating the equations (and `match`-splitter), we wrap the conditions with the gadget `Grind.MatchCond`. +This gadget is used for implementing truth-value propagation. See the propagator `propagateMatchCond` below. +For example, given a condition `C` of the form `Grind.MatchCond (∀ (a : Nat), t = S.mk1 a → False)`, +if `t` is merged with an equivalence class containing `S.mk2 n s`, then `C` is asseted to `true` by `propagateMatchCond`. + +This module also provides auxiliary functions for detecting congruences between `match`-expression conditions. +See function `collectMatchCondLhssAndAbstract`. + +Remark: This note highlights that the representation used for encoding `match`-expressions with +overlapping patterns is far from ideal for the `grind` module which operates with equivalence classes +and does not perform substitutions like `simp`. While modifying how `match`-expressions are encoded in Lean +would require major refactoring and affect many modules, this issue is important to acknowledge. +A different representation could simplify `grind`, but it could add extra complexity to other modules. +-/ /-- Returns `Grind.MatchCond e`. @@ -17,8 +78,8 @@ Recall that `Grind.MatchCond` is an identity function, but the following simproc is used to prevent the term `e` from being simplified, and we have special support for propagating is truth value. -/ -def markAsMatchCond (e : Expr) : MetaM Expr := - mkAppM ``Grind.MatchCond #[e] +def markAsMatchCond (e : Expr) : Expr := + mkApp (mkConst ``Grind.MatchCond) e builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do let_expr Grind.MatchCond _ ← e | return .continue @@ -28,15 +89,103 @@ builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do def addMatchCond (s : Simprocs) : CoreM Simprocs := do s.add ``reduceMatchCond (post := false) +/-- +Returns `some (lhs, rhs, isHEq)` if `e` is of the form +- `Eq _ lhs rhs` (`isHEq := false`), or +- `HEq _ lhs _ rhs` (`isHEq := true`) +-/ +private def isEqHEq? (e : Expr) : Option (Expr × Expr × Bool) := + match_expr e with + | Eq _ lhs rhs => some (lhs, rhs, false) + | HEq _ lhs _ rhs => some (lhs, rhs, true) + | _ => none + +/-- +Given `e` a `match`-expression condition, returns the left-hand side +of the ground equations. +-/ +private def collectMatchCondLhss (e : Expr) : Array Expr := Id.run do + let mut r := #[] + let mut e := e + repeat + let .forallE _ d b _ := e | return r + if let some (lhs, _, _) := isEqHEq? d then + unless lhs.hasLooseBVars do + r := r.push lhs + e := b + return r + +/-- +Replaces the left-hand side of an equality (or heterogeneous equality) `e` with `lhsNew`. +-/ +private def replaceLhs? (e : Expr) (lhsNew : Expr) : Option Expr := + match_expr e with + | f@Eq α lhs rhs => if lhs.hasLooseBVars then none else some (mkApp3 f α lhsNew rhs) + | f@HEq α lhs β rhs => if lhs.hasLooseBVars then none else some (mkApp4 f α lhsNew β rhs) + | _ => none + +/-- +Given `e` a `match`-expression condition, returns the left-hand side +of the ground equations, **and** function application that abstracts the left-hand sides. +As an example, assume we have a `match`-expression condition `C₁` of the form +``` +Grind.MatchCond (∀ y₁ y₂ y₃, t = .mk₁ y₁ → s = .mk₂ y₂ y₃ → False) +``` +then the result returned by this function is +``` +(#[t, s], (fun x₁ x₂ => (∀ y₁ y₂ y₃, x₁ = .mk₁ y₁ → x₂ = .mk₂ y₂ y₃ → False)) t s) +``` +Note that the returned expression is definitionally equal to `C₁`. +We use this expression to detect whether two different `match`-expression conditions are +congruent. +For example, suppose we also have the `match`-expression `C₂` of the form +``` +Grind.MatchCond (∀ y₁ y₂ y₃, a = .mk₁ y₁ → b = .mk₂ y₂ y₃ → False) +``` +This function would return +``` +(#[a, b], (fun x₁ x₂ => (∀ y₁ y₂ y₃, x₁ = .mk₁ y₁ → x₂ = .mk₂ y₂ y₃ → False)) a b) +``` +Note that the lambda abstraction is identical to the first one. Let's call it `l`. +Thus, we can write the two pairs above as +- `(#[t, s], l t s)` +- `(#[a, b], l a b)` +Moreover, `C₁` is definitionally equal to `l t s`, and `C₂` is definitionally equal to `l a b`. +Then, if `grind` infers that `t = a` and `s = b`, it will detect that `l t s` and `l a b` are +equal by congruence, and consequently `C₁` is equal to `C₂`. +-/ +def collectMatchCondLhssAndAbstract (matchCond : Expr) : GoalM (Array Expr × Expr) := do + let_expr Grind.MatchCond e := matchCond | return (#[], matchCond) + let lhss := collectMatchCondLhss e + let rec go (i : Nat) (xs : Array Expr) : GoalM Expr := do + if h : i < lhss.size then + let lhs := lhss[i] + withLocalDeclD ((`x).appendIndexAfter i) (← inferType lhs) fun x => + go (i+1) (xs.push x) + else + let rec replaceLhss (e : Expr) (i : Nat) : Expr := Id.run do + let .forallE _ d b _ := e | return e + if h : i < xs.size then + if let some dNew := replaceLhs? d xs[i] then + return e.updateForallE! dNew (replaceLhss b (i+1)) + else + return e.updateForallE! d (replaceLhss b i) + else + return e + let eAbst := replaceLhss e 0 + let eLam ← mkLambdaFVars xs eAbst + let e' := mkAppN eLam lhss + shareCommon e' + let e' ← go 0 #[] + return (lhss, e') + /-- Helper function for `isSatisfied`. See `isSatisfied`. -/ -private partial def isMathCondFalseHyp (e : Expr) : GoalM Bool := do - match_expr e with - | Eq _ lhs rhs => isFalse lhs rhs - | HEq _ lhs _ rhs => isFalse lhs rhs - | _ => return false +private partial def isMatchCondFalseHyp (e : Expr) : GoalM Bool := do + let some (lhs, rhs, _) := isEqHEq? e | return false + isFalse lhs rhs where isFalse (lhs rhs : Expr) : GoalM Bool := do if lhs.hasLooseBVars then return false @@ -91,18 +240,19 @@ private partial def isStatisfied (e : Expr) : GoalM Bool := do let mut e := e repeat let .forallE _ d b _ := e | break - if (← isMathCondFalseHyp d) then + if (← isMatchCondFalseHyp d) then trace[grind.debug.matchCond] "satifised{indentExpr e}\nthe following equality is false{indentExpr d}" return true e := b return false -private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do +/-- Constructs a proof for a satisfied `match`-expression condition. -/ +private partial def mkMatchCondProof? (e : Expr) : GoalM (Option Expr) := do let_expr Grind.MatchCond f ← e | return none forallTelescopeReducing f fun xs _ => do for x in xs do let type ← inferType x - if (← isMathCondFalseHyp type) then + if (← isMatchCondFalseHyp type) then trace[grind.debug.matchCond] ">>> {type}" let some h ← go? x | pure () return some (← mkLambdaFVars xs h) @@ -110,10 +260,8 @@ private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do where go? (h : Expr) : GoalM (Option Expr) := do trace[grind.debug.matchCond] "go?: {← inferType h}" - let (lhs, rhs, isHeq) ← match_expr (← inferType h) with - | Eq _ lhs rhs => pure (lhs, rhs, false) - | HEq _ lhs _ rhs => pure (lhs, rhs, true) - | _ => return none + let some (lhs, rhs, isHeq) := isEqHEq? (← inferType h) + | return none let target ← (← get).mvarId.getType let root ← getRootENode lhs let h ← if isHeq then @@ -147,7 +295,7 @@ where builtin_grind_propagator propagateMatchCond ↑Grind.MatchCond := fun e => do trace[grind.debug.matchCond] "visiting{indentExpr e}" if !(← isStatisfied e) then return () - let some h ← mkMathCondProof? e + let some h ← mkMatchCondProof? e | reportIssue m!"failed to construct proof for{indentExpr e}"; return () trace[grind.debug.matchCond] "{← inferType h}" pushEqTrue e <| mkEqTrueCore e h diff --git a/tests/lean/run/grind_match_eq_propagation.lean b/tests/lean/run/grind_match_eq_propagation.lean index 05d364fb6b..d958f54fde 100644 --- a/tests/lean/run/grind_match_eq_propagation.lean +++ b/tests/lean/run/grind_match_eq_propagation.lean @@ -12,36 +12,47 @@ def f (x y : S) := | _, _ => 5 example : f a b < 2 → b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → False := by - unfold f - grind (splits := 0) + grind (splits := 0) [f.eq_def] example : b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → f a b = 5 := by - unfold f - grind (splits := 0) + grind (splits := 0) [f.eq_def] example : b = .mk2 y1 y2 → y1 = 2 → a = .mk3 n → f a b = 4 := by - unfold f - grind (splits := 0) + grind (splits := 0) [f.eq_def] example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk3 n → f a b = 3 := by - unfold f - grind (splits := 0) + grind (splits := 0) [f.eq_def] example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk2 s3 s4 → f a b = 3 := by - unfold f - grind (splits := 0) + grind (splits := 0) [f.eq_def] + +example : f a b > 1 := by + grind (splits := 1) [f.eq_def] + +example : f a b > 1 := by + grind [f.eq_def] + +def g (x y : S) := + match x, y with + | .mk1 a, _ => a + 2 + | _, .mk2 1 (.mk4 _ _) => 3 + | .mk3 _, .mk4 _ _ => 4 + | _, _ => 5 + +example : g a b > 1 := by + grind [g.eq_def] inductive Vec (α : Type u) : Nat → Type u | nil : Vec α 0 | cons : α → Vec α n → Vec α (n+1) -def g (v w : Vec α n) : Nat := +def h (v w : Vec α n) : Nat := match v, w with | _, .cons _ (.cons _ _) => 20 | .nil, _ => 30 | _, _ => 40 --- TODO: introduce casts while instantiating equation theorems for `g.match_1` --- example (a b : Vec α 2) : g a b = 20 := by --- unfold g +-- TODO: introduce casts while instantiating equation theorems for `h.match_1` +-- example (a b : Vec α 2) : h a b = 20 := by +-- unfold h -- grind