From eb5ad2c03a1ec4bb4d2fecb63a3cef1076d62a2a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 25 Feb 2025 19:34:39 -0800 Subject: [PATCH] feat: disequality propagation from `grind` core module to cutsat (#7234) This PR implements dIsequality propagation from `grind` core module to cutsat. --- src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean | 3 + .../Tactic/Grind/Arith/Cutsat/EqCnstr.lean | 9 ++- src/Lean/Meta/Tactic/Grind/Core.lean | 9 ++- src/Lean/Meta/Tactic/Grind/Diseq.lean | 6 -- src/Lean/Meta/Tactic/Grind/Propagate.lean | 1 + src/Lean/Meta/Tactic/Grind/Types.lean | 56 ++++++++++++++++- tests/lean/run/grind_cutsat_diseq_1.lean | 63 +++++++++++++++++++ 7 files changed, 136 insertions(+), 11 deletions(-) create mode 100644 tests/lean/run/grind_cutsat_diseq_1.lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean index c53d69626f..0e5e55c1ff 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean @@ -43,4 +43,7 @@ builtin_initialize registerTraceClass `grind.cutsat.le.upper (inherited := true) builtin_initialize registerTraceClass `grind.cutsat.assign builtin_initialize registerTraceClass `grind.cutsat.conflict +builtin_initialize registerTraceClass `grind.debug.cutsat.eq +builtin_initialize registerTraceClass `grind.debug.cutsat.diseq + end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean index 89dbe10773..b6f72b357f 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean @@ -151,14 +151,16 @@ private def exprAsPoly (a : Expr) : GoalM Poly := do @[export lean_process_cutsat_eq] def processNewEqImpl (a b : Expr) : GoalM Unit := do + trace[grind.debug.cutsat.eq] "{a} = {b}" let p₁ ← exprAsPoly a let p₂ ← exprAsPoly b let p := p₁.combine (p₂.mul (-1)) let c ← mkEqCnstr p (.core p₁ p₂ (← mkEqProof a b)) c.assert -@[export lean_process_new_cutsat_lit] +@[export lean_process_cutsat_eq_lit] def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do + trace[grind.debug.cutsat.eq] "{a} = {ke}" let some k ← getIntValue? ke | return () let p₁ ← exprAsPoly a let h ← mkEqProof a ke @@ -170,6 +172,11 @@ def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do mkEqCnstr p (.core p₁ p₂ h) c.assert +@[export lean_process_cutsat_diseq] +def processNewDiseqImpl (a b : Expr) : GoalM Unit := do + trace[grind.debug.cutsat.diseq] "{a} ≠ {b}" + -- TODO + /-- Different kinds of terms internalized by this module. -/ private inductive SupportedTermKind where | add | mul | num diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index 4723e1b149..c1c028db18 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -124,10 +124,13 @@ private def propagateCutsatEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do -- We have to retrieve the node because other fields have been updated let rhsRoot ← getENode rhsRoot.self setENode rhsRoot.self { rhsRoot with cutsat? := lhsCutsat } + propagateCutsatDiseqs rhsRoot.self | none => - if isIntNum lhsRoot.self then if let some rhsCutsat := rhsRoot.cutsat? then - Arith.Cutsat.processNewEqLit rhsCutsat lhsRoot.self + if isIntNum lhsRoot.self then + Arith.Cutsat.processNewEqLit rhsCutsat lhsRoot.self + else + propagateCutsatDiseqs lhsRoot.self /-- Tries to apply beta-reductiong using the parent applications of the functions in `fns` with @@ -225,9 +228,9 @@ where } propagateBeta lams₁ fns₁ propagateBeta lams₂ fns₂ - resetParentsOf lhsRoot.self propagateOffsetEq rhsRoot lhsRoot propagateCutsatEq rhsRoot lhsRoot + resetParentsOf lhsRoot.self copyParentsTo parents rhsNode.root unless (← isInconsistent) do updateMT rhsRoot.self diff --git a/src/Lean/Meta/Tactic/Grind/Diseq.lean b/src/Lean/Meta/Tactic/Grind/Diseq.lean index 6c2d63312d..1f5d223ae0 100644 --- a/src/Lean/Meta/Tactic/Grind/Diseq.lean +++ b/src/Lean/Meta/Tactic/Grind/Diseq.lean @@ -9,12 +9,6 @@ import Lean.Meta.Tactic.Grind.Types namespace Lean.Meta.Grind -/-- -Returns `true` if type of `t` is definitionally equal to `α` --/ -private def hasType (t α : Expr) : MetaM Bool := - withDefault do isDefEq (← inferType t) α - /-- Returns `some (c = d)` if - `c = d` and `False` are in the same equivalence class, and diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index 5a8c4057e3..6069f7dd19 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -146,6 +146,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do pushEq a b <| mkOfEqTrueCore e (← mkEqTrueProof e) else if (← isEqFalse e) then let_expr Eq α lhs rhs := e | return () + propagateCutsatDiseq lhs rhs let thms ← getExtTheorems α if !thms.isEmpty then /- diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index ae4bf58f0f..93541bd4fc 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -866,9 +866,16 @@ opaque Arith.Cutsat.processNewEq (a b : Expr) : GoalM Unit Notifies the cutsat module that `a = k` where `a` is term that has been internalized by this module, and `k` is a numeral. -/ -@[extern "lean_process_new_cutsat_lit"] -- forward definition +@[extern "lean_process_cutsat_eq_lit"] -- forward definition opaque Arith.Cutsat.processNewEqLit (a k : Expr) : GoalM Unit +/-- +Notifies the cutsat module that `a ≠ b` where +`a` and `b` are terms that have been internalized by this module. +-/ +@[extern "lean_process_cutsat_diseq"] -- forward definition +opaque Arith.Cutsat.processNewDiseq (a b : Expr) : GoalM Unit + /-- Returns `true` if `e` is a nonegative numeral and has type `Int`. -/ def isNonnegIntNum (e : Expr) : Bool := Id.run do let_expr OfNat.ofNat _ _ inst := e | false @@ -883,6 +890,52 @@ def isIntNum (e : Expr) : Bool := isNonnegIntNum e | _ => isNonnegIntNum e +/-- +Returns `true` if type of `t` is definitionally equal to `α` +-/ +def hasType (t α : Expr) : MetaM Bool := + withDefault do isDefEq (← inferType t) α + +/-- +For each equality `b = c` in `parents`, executes `k b c` IF +- `b = c` is equal to `False`, and +- `a` is the equivalence class of `b` or `c`, and +- type of `a` is definitionally equal to types of `b` and `c`. +-/ +@[inline] def forEachDiseqOfCore (a : Expr) (parents : ParentSet) (k : (lhs : Expr) → (rhs : Expr) → GoalM Unit) : GoalM Unit := do + for parent in parents do + let_expr Eq α b c := parent | continue + if (← isEqFalse parent) then + if (← isEqv a b <||> isEqv a c) then + if (← hasType a α) then + k b c + +/-- +For each equality `b = c` in `(← getParents a)`, executes `k b c` IF +- `b = c` is equal to `False`, and +- `a` is the equivalence class of `b` or `c`, and +- type of `a` is definitionally equal to types of `b` and `c`. +-/ +@[inline] def forEachDiseqOf (a : Expr) (k : (lhs : Expr) → (rhs : Expr) → GoalM Unit) : GoalM Unit := do + forEachDiseqOfCore a (← getParents a) k + +/-- +Given `lhs` and `rhs` that are known to be disequal, checks whether +`lhs` and `rhs` have cutsat terms `e₁` and `e₂` attached to them, +and invokes process `Arith.Cutsat.processNewDiseq e₁ e₂` +-/ +def propagateCutsatDiseq (lhs rhs : Expr) : GoalM Unit := do + let { cutsat? := some e₁, .. } ← getRootENode lhs | return () + let { cutsat? := some e₂, .. } ← getRootENode rhs | return () + Arith.Cutsat.processNewDiseq e₁ e₂ + +/-- +Traverses all known disequalities about `e`, and propagate the ones relevant to the +cutsat module. +-/ +def propagateCutsatDiseqs (e : Expr) : GoalM Unit := do + forEachDiseqOf e propagateCutsatDiseq + /-- Marks `e` as a term of interest to the cutsat module. If the root of `e`s equivalence class has already a term of interest, @@ -896,6 +949,7 @@ def markAsCutsatTerm (e : Expr) : GoalM Unit := do Arith.Cutsat.processNewEqLit e root.self else setENode root.self { root with cutsat? := some e } + propagateCutsatDiseqs root.self /-- Returns `true` is `e` is the root of its congruence class. -/ def isCongrRoot (e : Expr) : GoalM Bool := do diff --git a/tests/lean/run/grind_cutsat_diseq_1.lean b/tests/lean/run/grind_cutsat_diseq_1.lean new file mode 100644 index 0000000000..d220490402 --- /dev/null +++ b/tests/lean/run/grind_cutsat_diseq_1.lean @@ -0,0 +1,63 @@ +set_option grind.warning false +set_option grind.debug true +open Int.Linear + +set_option trace.grind.debug.cutsat.diseq true +set_option trace.grind.debug.cutsat.eq true + +/-- info: [grind.debug.cutsat.diseq] a ≠ b -/ +#guard_msgs (info) in +example (a b : Int) : a + b < 0 → a ≠ b → False := by + (fail_if_success grind); sorry + +#guard_msgs (info) in -- `a` and `b` are not relevant to cutsat in the following example +example (a b : Int) : a ≠ b → False := by + (fail_if_success grind); sorry + +/-- info: [grind.debug.cutsat.diseq] a ≠ b -/ +#guard_msgs (info) in +example (a b : Int) : a ≠ b → a + b < 0 → False := by + (fail_if_success grind); sorry + +/-- info: [grind.debug.cutsat.diseq] a ≠ b -/ +#guard_msgs (info) in +example (a b c : Int) : a ≠ c → c = b → a + b < 0 → False := by + (fail_if_success grind); sorry + +/-- info: [grind.debug.cutsat.diseq] a ≠ b -/ +#guard_msgs (info) in +example (a b c d : Int) : d ≠ c → c = b → a = d → a + b < 0 → False := by + (fail_if_success grind); sorry + +/-- info: [grind.debug.cutsat.diseq] a ≠ b -/ +#guard_msgs (info) in +example (a b c d : Int) : d ≠ c → a = d → a + b < 0 → c = b → False := by + (fail_if_success grind); sorry + +/-- +info: [grind.debug.cutsat.diseq] a ≠ b +[grind.debug.cutsat.eq] e = b +-/ +#guard_msgs (info) in +example (a b c d e : Int) : d ≠ c → a = d → a + b < 0 → c = b → c = e → e > 0 → False := by + (fail_if_success grind); sorry + +/-- +info: [grind.debug.cutsat.eq] b = e +[grind.debug.cutsat.diseq] a ≠ e +-/ +#guard_msgs (info) in +example (a b c d e : Int) : d ≠ c → a = d → c = b → c = e → e > 0 → a + b < 0 → False := by + (fail_if_success grind); sorry + +/-- +info: [grind.debug.cutsat.eq] b = e +[grind.debug.cutsat.diseq] a ≠ e +-/ +#guard_msgs (info) in +example (a b c d e : Int) : a = d → c = b → c = e → e > 0 → a + b < 0 → d ≠ c → False := by + (fail_if_success grind); sorry + +#guard_msgs (info) in -- no propagation to cutsat +example (a b c d e : Int) : a = d → c = b → c = e → a = 1 → d ≠ c → False := by + (fail_if_success grind); sorry