fix: theory propagation in grind (#8198)

This PR fixes an issue in the theory propagation used in `grind`. When
two equivalence classes are merged, the core may need to push additional
equalities or disequalities down to the satellite theory solvers (e.g.,
`cutsat`, `comm ring`, etc). Some solvers (e.g. `cutsat`) assume that
all of the core’s invariants hold before they receive those facts.
Propagating immediately therefore risks violating a solver’s
pre-conditions midway through the merge. To decouple the merge operation
from propagation and to keep the core solver-agnostic, this PR adds the
helper type `PendingTheoryPropagation`.
This commit is contained in:
Leonardo de Moura 2025-05-01 19:19:56 -07:00 committed by GitHub
parent 1143b4766c
commit d26d7973ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 77 additions and 44 deletions

View file

@ -90,75 +90,116 @@ private partial def updateMT (root : Expr) : GoalM Unit := do
updateMT parent
/--
Helper function for combining `ENode.offset?` fields and propagating equalities
to the offset constraint module.
Equalities or disequalities to be propagated to a theory solver **after**
two equivalence classes have been merged.
Some solvers (e.g. `cutsat`) require the core data structures to satisfy
their invariants. During the merge operations some of these invariants do not hold.
Thus, we first *record* the facts that must be propagated in a `PendingTheoryPropagation` value,
complete the merge, and only then perform the propagation.
We now use this workflow for *all* theory solvers, even when a particular
solver does not rely on these invariants. This keeps the core
solver-agnostic and lets us modify solvers without further adjustments.
-/
private def propagateOffsetEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do
inductive PendingTheoryPropagation where
| /-- Nothing to propagate. -/
none
| /-- Propagate the equality `lhs = rhs`. -/
eq (lhs rhs : Expr)
|
/--
Propagate the literal equality `lhs = lit`.
This is needed because some solvers do not internalize literal values.
Remark: we may remove this optimization in the future because it adds complexity
for a small performance gain.
-/
eqLit (lhs lit : Expr)
| /-- Propagate the disequalities in `ps`. -/
diseqs (ps : ParentSet)
/--
Helper function for combining `ENode.offset?` fields and detecting what needs
to be propagated to the offset constraint module.
-/
private def checkOffsetEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
match lhsRoot.offset? with
| some lhsOffset =>
if let some rhsOffset := rhsRoot.offset? then
Arith.Offset.processNewEq lhsOffset rhsOffset
return .eq lhsOffset rhsOffset
else if isNatNum rhsRoot.self then
Arith.Offset.processNewEqLit lhsOffset rhsRoot.self
return .eqLit lhsOffset rhsRoot.self
else
-- We have to retrieve the node because other fields have been updated
let rhsRoot ← getENode rhsRoot.self
setENode rhsRoot.self { rhsRoot with offset? := lhsOffset }
return .none
| none =>
if isNatNum lhsRoot.self then
if let some rhsOffset := rhsRoot.offset? then
Arith.Offset.processNewEqLit rhsOffset lhsRoot.self
if let some rhsOffset := rhsRoot.offset? then
return .eqLit rhsOffset lhsRoot.self
return .none
def propagateOffset : PendingTheoryPropagation → GoalM Unit
| .eq lhs rhs => Arith.Offset.processNewEq lhs rhs
| .eqLit lhs lit => Arith.Offset.processNewEqLit lhs lit
| _ => return ()
/--
Helper function for combining `ENode.cutsat?` fields and propagating equalities
to the cutsat module.
It returns a set of parents that should be traversed for disequality propagation.
Helper function for combining `ENode.cutsat?` fields and detecting what needs
to be propagated to the cutsat module.
-/
private def propagateCutsatEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do
private def checkCutsatEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
match lhsRoot.cutsat? with
| some lhsCutsat =>
if let some rhsCutsat := rhsRoot.cutsat? then
Arith.Cutsat.processNewEq lhsCutsat rhsCutsat
return {}
return .eq lhsCutsat rhsCutsat
else if isNum rhsRoot.self then
Arith.Cutsat.processNewEqLit lhsCutsat rhsRoot.self
return {}
return .eqLit lhsCutsat rhsRoot.self
else
-- We have to retrieve the node because other fields have been updated
let rhsRoot ← getENode rhsRoot.self
setENode rhsRoot.self { rhsRoot with cutsat? := lhsCutsat }
getParents rhsRoot.self
return .diseqs (← getParents rhsRoot.self)
| none =>
if let some rhsCutsat := rhsRoot.cutsat? then
if isNum lhsRoot.self then
Arith.Cutsat.processNewEqLit rhsCutsat lhsRoot.self
return {}
return .eqLit rhsCutsat lhsRoot.self
else
getParents lhsRoot.self
return .diseqs (← getParents lhsRoot.self)
else
return {}
return .none
def propagateCutsat : PendingTheoryPropagation → GoalM Unit
| .eq lhs rhs => Arith.Cutsat.processNewEq lhs rhs
| .eqLit lhs lit => Arith.Cutsat.processNewEqLit lhs lit
| .diseqs ps => propagateCutsatDiseqs ps
| .none => return ()
/--
Helper function for combining `ENode.ring?` fields and propagating equalities
to the commutative ring module.
It returns a set of parents that should be traversed for disequality propagation.
Helper function for combining `ENode.ring?` fields and detecting what needs to be
progagated to the commutative ring module.
-/
private def propagateCommRingEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do
private def checkCommRingEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
match lhsRoot.ring? with
| some lhsRing =>
if let some rhsRing := rhsRoot.ring? then
Arith.CommRing.processNewEq lhsRing rhsRing
return {}
return .eq lhsRing rhsRing
else
-- We have to retrieve the node because other fields have been updated
let rhsRoot ← getENode rhsRoot.self
setENode rhsRoot.self { rhsRoot with ring? := lhsRing }
getParents rhsRoot.self
return .diseqs (← getParents rhsRoot.self)
| none =>
if rhsRoot.ring?.isSome then
getParents lhsRoot.self
return .diseqs (← getParents lhsRoot.self)
else
return {}
return .none
def propagateCommRing : PendingTheoryPropagation → GoalM Unit
| .eq lhs rhs => Arith.CommRing.processNewEq lhs rhs
| .diseqs ps => propagateCommRingDiseqs ps
| _ => return ()
/--
Tries to apply beta-reductiong using the parent applications of the functions in `fns` with
@ -262,9 +303,9 @@ where
}
propagateBeta lams₁ fns₁
propagateBeta lams₂ fns₂
propagateOffsetEq rhsRoot lhsRoot
let parentsToPropagateCutsatDiseqs ← propagateCutsatEq rhsRoot lhsRoot
let parentsToPropagateRingDiseqs ← propagateCommRingEq rhsRoot lhsRoot
let offsetTodo ← checkOffsetEq rhsRoot lhsRoot
let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot
let ringTodo ← checkCommRingEq rhsRoot lhsRoot
resetParentsOf lhsRoot.self
copyParentsTo parents rhsNode.root
unless (← isInconsistent) do
@ -274,8 +315,9 @@ where
propagateUp parent
for e in toPropagateDown do
propagateDown e
propagateCutsatDiseqs parentsToPropagateCutsatDiseqs
propagateCommRingDiseqs parentsToPropagateRingDiseqs
propagateOffset offsetTodo
propagateCutsat cutsatTodo
propagateCommRing ringTodo
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
traverseEqc lhs fun n =>
setENode n.self { n with root := rootNew }

View file

@ -1,5 +1,5 @@
open Array
set_option grind.warning false
reset_grind_attrs%
attribute [grind] Vector.getElem_swap_of_ne
@ -15,15 +15,6 @@ theorem qpartition_loop_spec₁ {n} (lt : αα → Bool) (lo hi : Nat)
∀ k, (h₁ : lo ≤ k) → (h₂ : k < mid) → lt as'[k] as'[mid] := by
sorry
/--
warning: The `grind` tactic is experimental and still under development. Avoid using it in production projects.
---
error: internal `grind` error, failed to build disequality proof for
(lo + hi) / 2
and
lo
-/
#guard_msgs in
example {n} (lt : αα → Bool) (lo hi : Nat)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) (w : lo ≤ hi := by omega)
(as : Vector α n) (mid as')