fix: theory propagation in grind (#8198)
This PR fixes an issue in the theory propagation used in `grind`. When two equivalence classes are merged, the core may need to push additional equalities or disequalities down to the satellite theory solvers (e.g., `cutsat`, `comm ring`, etc). Some solvers (e.g. `cutsat`) assume that all of the core’s invariants hold before they receive those facts. Propagating immediately therefore risks violating a solver’s pre-conditions midway through the merge. To decouple the merge operation from propagation and to keep the core solver-agnostic, this PR adds the helper type `PendingTheoryPropagation`.
This commit is contained in:
parent
1143b4766c
commit
d26d7973ad
2 changed files with 77 additions and 44 deletions
|
|
@ -90,75 +90,116 @@ private partial def updateMT (root : Expr) : GoalM Unit := do
|
|||
updateMT parent
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.offset?` fields and propagating equalities
|
||||
to the offset constraint module.
|
||||
Equalities or disequalities to be propagated to a theory solver **after**
|
||||
two equivalence classes have been merged.
|
||||
|
||||
Some solvers (e.g. `cutsat`) require the core data structures to satisfy
|
||||
their invariants. During the merge operations some of these invariants do not hold.
|
||||
Thus, we first *record* the facts that must be propagated in a `PendingTheoryPropagation` value,
|
||||
complete the merge, and only then perform the propagation.
|
||||
|
||||
We now use this workflow for *all* theory solvers, even when a particular
|
||||
solver does not rely on these invariants. This keeps the core
|
||||
solver-agnostic and lets us modify solvers without further adjustments.
|
||||
-/
|
||||
private def propagateOffsetEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do
|
||||
inductive PendingTheoryPropagation where
|
||||
| /-- Nothing to propagate. -/
|
||||
none
|
||||
| /-- Propagate the equality `lhs = rhs`. -/
|
||||
eq (lhs rhs : Expr)
|
||||
|
|
||||
/--
|
||||
Propagate the literal equality `lhs = lit`.
|
||||
This is needed because some solvers do not internalize literal values.
|
||||
Remark: we may remove this optimization in the future because it adds complexity
|
||||
for a small performance gain.
|
||||
-/
|
||||
eqLit (lhs lit : Expr)
|
||||
| /-- Propagate the disequalities in `ps`. -/
|
||||
diseqs (ps : ParentSet)
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.offset?` fields and detecting what needs
|
||||
to be propagated to the offset constraint module.
|
||||
-/
|
||||
private def checkOffsetEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
|
||||
match lhsRoot.offset? with
|
||||
| some lhsOffset =>
|
||||
if let some rhsOffset := rhsRoot.offset? then
|
||||
Arith.Offset.processNewEq lhsOffset rhsOffset
|
||||
return .eq lhsOffset rhsOffset
|
||||
else if isNatNum rhsRoot.self then
|
||||
Arith.Offset.processNewEqLit lhsOffset rhsRoot.self
|
||||
return .eqLit lhsOffset rhsRoot.self
|
||||
else
|
||||
-- We have to retrieve the node because other fields have been updated
|
||||
let rhsRoot ← getENode rhsRoot.self
|
||||
setENode rhsRoot.self { rhsRoot with offset? := lhsOffset }
|
||||
return .none
|
||||
| none =>
|
||||
if isNatNum lhsRoot.self then
|
||||
if let some rhsOffset := rhsRoot.offset? then
|
||||
Arith.Offset.processNewEqLit rhsOffset lhsRoot.self
|
||||
if let some rhsOffset := rhsRoot.offset? then
|
||||
return .eqLit rhsOffset lhsRoot.self
|
||||
return .none
|
||||
|
||||
def propagateOffset : PendingTheoryPropagation → GoalM Unit
|
||||
| .eq lhs rhs => Arith.Offset.processNewEq lhs rhs
|
||||
| .eqLit lhs lit => Arith.Offset.processNewEqLit lhs lit
|
||||
| _ => return ()
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.cutsat?` fields and propagating equalities
|
||||
to the cutsat module.
|
||||
It returns a set of parents that should be traversed for disequality propagation.
|
||||
Helper function for combining `ENode.cutsat?` fields and detecting what needs
|
||||
to be propagated to the cutsat module.
|
||||
-/
|
||||
private def propagateCutsatEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do
|
||||
private def checkCutsatEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
|
||||
match lhsRoot.cutsat? with
|
||||
| some lhsCutsat =>
|
||||
if let some rhsCutsat := rhsRoot.cutsat? then
|
||||
Arith.Cutsat.processNewEq lhsCutsat rhsCutsat
|
||||
return {}
|
||||
return .eq lhsCutsat rhsCutsat
|
||||
else if isNum rhsRoot.self then
|
||||
Arith.Cutsat.processNewEqLit lhsCutsat rhsRoot.self
|
||||
return {}
|
||||
return .eqLit lhsCutsat rhsRoot.self
|
||||
else
|
||||
-- We have to retrieve the node because other fields have been updated
|
||||
let rhsRoot ← getENode rhsRoot.self
|
||||
setENode rhsRoot.self { rhsRoot with cutsat? := lhsCutsat }
|
||||
getParents rhsRoot.self
|
||||
return .diseqs (← getParents rhsRoot.self)
|
||||
| none =>
|
||||
if let some rhsCutsat := rhsRoot.cutsat? then
|
||||
if isNum lhsRoot.self then
|
||||
Arith.Cutsat.processNewEqLit rhsCutsat lhsRoot.self
|
||||
return {}
|
||||
return .eqLit rhsCutsat lhsRoot.self
|
||||
else
|
||||
getParents lhsRoot.self
|
||||
return .diseqs (← getParents lhsRoot.self)
|
||||
else
|
||||
return {}
|
||||
return .none
|
||||
|
||||
def propagateCutsat : PendingTheoryPropagation → GoalM Unit
|
||||
| .eq lhs rhs => Arith.Cutsat.processNewEq lhs rhs
|
||||
| .eqLit lhs lit => Arith.Cutsat.processNewEqLit lhs lit
|
||||
| .diseqs ps => propagateCutsatDiseqs ps
|
||||
| .none => return ()
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.ring?` fields and propagating equalities
|
||||
to the commutative ring module.
|
||||
It returns a set of parents that should be traversed for disequality propagation.
|
||||
Helper function for combining `ENode.ring?` fields and detecting what needs to be
|
||||
progagated to the commutative ring module.
|
||||
-/
|
||||
private def propagateCommRingEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do
|
||||
private def checkCommRingEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
|
||||
match lhsRoot.ring? with
|
||||
| some lhsRing =>
|
||||
if let some rhsRing := rhsRoot.ring? then
|
||||
Arith.CommRing.processNewEq lhsRing rhsRing
|
||||
return {}
|
||||
return .eq lhsRing rhsRing
|
||||
else
|
||||
-- We have to retrieve the node because other fields have been updated
|
||||
let rhsRoot ← getENode rhsRoot.self
|
||||
setENode rhsRoot.self { rhsRoot with ring? := lhsRing }
|
||||
getParents rhsRoot.self
|
||||
return .diseqs (← getParents rhsRoot.self)
|
||||
| none =>
|
||||
if rhsRoot.ring?.isSome then
|
||||
getParents lhsRoot.self
|
||||
return .diseqs (← getParents lhsRoot.self)
|
||||
else
|
||||
return {}
|
||||
return .none
|
||||
|
||||
def propagateCommRing : PendingTheoryPropagation → GoalM Unit
|
||||
| .eq lhs rhs => Arith.CommRing.processNewEq lhs rhs
|
||||
| .diseqs ps => propagateCommRingDiseqs ps
|
||||
| _ => return ()
|
||||
|
||||
/--
|
||||
Tries to apply beta-reductiong using the parent applications of the functions in `fns` with
|
||||
|
|
@ -262,9 +303,9 @@ where
|
|||
}
|
||||
propagateBeta lams₁ fns₁
|
||||
propagateBeta lams₂ fns₂
|
||||
propagateOffsetEq rhsRoot lhsRoot
|
||||
let parentsToPropagateCutsatDiseqs ← propagateCutsatEq rhsRoot lhsRoot
|
||||
let parentsToPropagateRingDiseqs ← propagateCommRingEq rhsRoot lhsRoot
|
||||
let offsetTodo ← checkOffsetEq rhsRoot lhsRoot
|
||||
let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot
|
||||
let ringTodo ← checkCommRingEq rhsRoot lhsRoot
|
||||
resetParentsOf lhsRoot.self
|
||||
copyParentsTo parents rhsNode.root
|
||||
unless (← isInconsistent) do
|
||||
|
|
@ -274,8 +315,9 @@ where
|
|||
propagateUp parent
|
||||
for e in toPropagateDown do
|
||||
propagateDown e
|
||||
propagateCutsatDiseqs parentsToPropagateCutsatDiseqs
|
||||
propagateCommRingDiseqs parentsToPropagateRingDiseqs
|
||||
propagateOffset offsetTodo
|
||||
propagateCutsat cutsatTodo
|
||||
propagateCommRing ringTodo
|
||||
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
|
||||
traverseEqc lhs fun n =>
|
||||
setENode n.self { n with root := rootNew }
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
open Array
|
||||
|
||||
set_option grind.warning false
|
||||
reset_grind_attrs%
|
||||
|
||||
attribute [grind] Vector.getElem_swap_of_ne
|
||||
|
|
@ -15,15 +15,6 @@ theorem qpartition_loop_spec₁ {n} (lt : α → α → Bool) (lo hi : Nat)
|
|||
∀ k, (h₁ : lo ≤ k) → (h₂ : k < mid) → lt as'[k] as'[mid] := by
|
||||
sorry
|
||||
|
||||
/--
|
||||
warning: The `grind` tactic is experimental and still under development. Avoid using it in production projects.
|
||||
---
|
||||
error: internal `grind` error, failed to build disequality proof for
|
||||
(lo + hi) / 2
|
||||
and
|
||||
lo
|
||||
-/
|
||||
#guard_msgs in
|
||||
example {n} (lt : α → α → Bool) (lo hi : Nat)
|
||||
(hlo : lo < n := by omega) (hhi : hi < n := by omega) (w : lo ≤ hi := by omega)
|
||||
(as : Vector α n) (mid as')
|
||||
Loading…
Add table
Reference in a new issue