feat: disequality propagation from grind core module to cutsat (#7234)

This PR implements dIsequality propagation from `grind` core module to
cutsat.
This commit is contained in:
Leonardo de Moura 2025-02-25 19:34:39 -08:00 committed by GitHub
parent 769fe4ebf6
commit eb5ad2c03a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 136 additions and 11 deletions

View file

@ -43,4 +43,7 @@ builtin_initialize registerTraceClass `grind.cutsat.le.upper (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.assign
builtin_initialize registerTraceClass `grind.cutsat.conflict
builtin_initialize registerTraceClass `grind.debug.cutsat.eq
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq
end Lean

View file

@ -151,14 +151,16 @@ private def exprAsPoly (a : Expr) : GoalM Poly := do
@[export lean_process_cutsat_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.eq] "{a} = {b}"
let p₁ ← exprAsPoly a
let p₂ ← exprAsPoly b
let p := p₁.combine (p₂.mul (-1))
let c ← mkEqCnstr p (.core p₁ p₂ (← mkEqProof a b))
c.assert
@[export lean_process_new_cutsat_lit]
@[export lean_process_cutsat_eq_lit]
def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.eq] "{a} = {ke}"
let some k ← getIntValue? ke | return ()
let p₁ ← exprAsPoly a
let h ← mkEqProof a ke
@ -170,6 +172,11 @@ def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do
mkEqCnstr p (.core p₁ p₂ h)
c.assert
@[export lean_process_cutsat_diseq]
def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.diseq] "{a} ≠ {b}"
-- TODO
/-- Different kinds of terms internalized by this module. -/
private inductive SupportedTermKind where
| add | mul | num

View file

@ -124,10 +124,13 @@ private def propagateCutsatEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do
-- We have to retrieve the node because other fields have been updated
let rhsRoot ← getENode rhsRoot.self
setENode rhsRoot.self { rhsRoot with cutsat? := lhsCutsat }
propagateCutsatDiseqs rhsRoot.self
| none =>
if isIntNum lhsRoot.self then
if let some rhsCutsat := rhsRoot.cutsat? then
Arith.Cutsat.processNewEqLit rhsCutsat lhsRoot.self
if isIntNum lhsRoot.self then
Arith.Cutsat.processNewEqLit rhsCutsat lhsRoot.self
else
propagateCutsatDiseqs lhsRoot.self
/--
Tries to apply beta-reductiong using the parent applications of the functions in `fns` with
@ -225,9 +228,9 @@ where
}
propagateBeta lams₁ fns₁
propagateBeta lams₂ fns₂
resetParentsOf lhsRoot.self
propagateOffsetEq rhsRoot lhsRoot
propagateCutsatEq rhsRoot lhsRoot
resetParentsOf lhsRoot.self
copyParentsTo parents rhsNode.root
unless (← isInconsistent) do
updateMT rhsRoot.self

View file

@ -9,12 +9,6 @@ import Lean.Meta.Tactic.Grind.Types
namespace Lean.Meta.Grind
/--
Returns `true` if type of `t` is definitionally equal to `α`
-/
private def hasType (t α : Expr) : MetaM Bool :=
withDefault do isDefEq (← inferType t) α
/--
Returns `some (c = d)` if
- `c = d` and `False` are in the same equivalence class, and

View file

@ -146,6 +146,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
pushEq a b <| mkOfEqTrueCore e (← mkEqTrueProof e)
else if (← isEqFalse e) then
let_expr Eq α lhs rhs := e | return ()
propagateCutsatDiseq lhs rhs
let thms ← getExtTheorems α
if !thms.isEmpty then
/-

View file

@ -866,9 +866,16 @@ opaque Arith.Cutsat.processNewEq (a b : Expr) : GoalM Unit
Notifies the cutsat module that `a = k` where
`a` is term that has been internalized by this module, and `k` is a numeral.
-/
@[extern "lean_process_new_cutsat_lit"] -- forward definition
@[extern "lean_process_cutsat_eq_lit"] -- forward definition
opaque Arith.Cutsat.processNewEqLit (a k : Expr) : GoalM Unit
/--
Notifies the cutsat module that `a ≠ b` where
`a` and `b` are terms that have been internalized by this module.
-/
@[extern "lean_process_cutsat_diseq"] -- forward definition
opaque Arith.Cutsat.processNewDiseq (a b : Expr) : GoalM Unit
/-- Returns `true` if `e` is a nonegative numeral and has type `Int`. -/
def isNonnegIntNum (e : Expr) : Bool := Id.run do
let_expr OfNat.ofNat _ _ inst := e | false
@ -883,6 +890,52 @@ def isIntNum (e : Expr) : Bool :=
isNonnegIntNum e
| _ => isNonnegIntNum e
/--
Returns `true` if type of `t` is definitionally equal to `α`
-/
def hasType (t α : Expr) : MetaM Bool :=
withDefault do isDefEq (← inferType t) α
/--
For each equality `b = c` in `parents`, executes `k b c` IF
- `b = c` is equal to `False`, and
- `a` is the equivalence class of `b` or `c`, and
- type of `a` is definitionally equal to types of `b` and `c`.
-/
@[inline] def forEachDiseqOfCore (a : Expr) (parents : ParentSet) (k : (lhs : Expr) → (rhs : Expr) → GoalM Unit) : GoalM Unit := do
for parent in parents do
let_expr Eq α b c := parent | continue
if (← isEqFalse parent) then
if (← isEqv a b <||> isEqv a c) then
if (← hasType a α) then
k b c
/--
For each equality `b = c` in `(← getParents a)`, executes `k b c` IF
- `b = c` is equal to `False`, and
- `a` is the equivalence class of `b` or `c`, and
- type of `a` is definitionally equal to types of `b` and `c`.
-/
@[inline] def forEachDiseqOf (a : Expr) (k : (lhs : Expr) → (rhs : Expr) → GoalM Unit) : GoalM Unit := do
forEachDiseqOfCore a (← getParents a) k
/--
Given `lhs` and `rhs` that are known to be disequal, checks whether
`lhs` and `rhs` have cutsat terms `e₁` and `e₂` attached to them,
and invokes process `Arith.Cutsat.processNewDiseq e₁ e₂`
-/
def propagateCutsatDiseq (lhs rhs : Expr) : GoalM Unit := do
let { cutsat? := some e₁, .. } ← getRootENode lhs | return ()
let { cutsat? := some e₂, .. } ← getRootENode rhs | return ()
Arith.Cutsat.processNewDiseq e₁ e₂
/--
Traverses all known disequalities about `e`, and propagate the ones relevant to the
cutsat module.
-/
def propagateCutsatDiseqs (e : Expr) : GoalM Unit := do
forEachDiseqOf e propagateCutsatDiseq
/--
Marks `e` as a term of interest to the cutsat module.
If the root of `e`s equivalence class has already a term of interest,
@ -896,6 +949,7 @@ def markAsCutsatTerm (e : Expr) : GoalM Unit := do
Arith.Cutsat.processNewEqLit e root.self
else
setENode root.self { root with cutsat? := some e }
propagateCutsatDiseqs root.self
/-- Returns `true` is `e` is the root of its congruence class. -/
def isCongrRoot (e : Expr) : GoalM Bool := do

View file

@ -0,0 +1,63 @@
set_option grind.warning false
set_option grind.debug true
open Int.Linear
set_option trace.grind.debug.cutsat.diseq true
set_option trace.grind.debug.cutsat.eq true
/-- info: [grind.debug.cutsat.diseq] a ≠ b -/
#guard_msgs (info) in
example (a b : Int) : a + b < 0 → a ≠ b → False := by
(fail_if_success grind); sorry
#guard_msgs (info) in -- `a` and `b` are not relevant to cutsat in the following example
example (a b : Int) : a ≠ b → False := by
(fail_if_success grind); sorry
/-- info: [grind.debug.cutsat.diseq] a ≠ b -/
#guard_msgs (info) in
example (a b : Int) : a ≠ b → a + b < 0 → False := by
(fail_if_success grind); sorry
/-- info: [grind.debug.cutsat.diseq] a ≠ b -/
#guard_msgs (info) in
example (a b c : Int) : a ≠ c → c = b → a + b < 0 → False := by
(fail_if_success grind); sorry
/-- info: [grind.debug.cutsat.diseq] a ≠ b -/
#guard_msgs (info) in
example (a b c d : Int) : d ≠ c → c = b → a = d → a + b < 0 → False := by
(fail_if_success grind); sorry
/-- info: [grind.debug.cutsat.diseq] a ≠ b -/
#guard_msgs (info) in
example (a b c d : Int) : d ≠ c → a = d → a + b < 0 → c = b → False := by
(fail_if_success grind); sorry
/--
info: [grind.debug.cutsat.diseq] a ≠ b
[grind.debug.cutsat.eq] e = b
-/
#guard_msgs (info) in
example (a b c d e : Int) : d ≠ c → a = d → a + b < 0 → c = b → c = e → e > 0 → False := by
(fail_if_success grind); sorry
/--
info: [grind.debug.cutsat.eq] b = e
[grind.debug.cutsat.diseq] a ≠ e
-/
#guard_msgs (info) in
example (a b c d e : Int) : d ≠ c → a = d → c = b → c = e → e > 0 → a + b < 0 → False := by
(fail_if_success grind); sorry
/--
info: [grind.debug.cutsat.eq] b = e
[grind.debug.cutsat.diseq] a ≠ e
-/
#guard_msgs (info) in
example (a b c d e : Int) : a = d → c = b → c = e → e > 0 → a + b < 0 → d ≠ c → False := by
(fail_if_success grind); sorry
#guard_msgs (info) in -- no propagation to cutsat
example (a b c d e : Int) : a = d → c = b → c = e → a = 1 → d ≠ c → False := by
(fail_if_success grind); sorry