diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index 269e6c32c3..bf35bed0f8 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -90,75 +90,116 @@ private partial def updateMT (root : Expr) : GoalM Unit := do updateMT parent /-- -Helper function for combining `ENode.offset?` fields and propagating equalities -to the offset constraint module. +Equalities or disequalities to be propagated to a theory solver **after** +two equivalence classes have been merged. + +Some solvers (e.g. `cutsat`) require the core data structures to satisfy +their invariants. During the merge operations some of these invariants do not hold. +Thus, we first *record* the facts that must be propagated in a `PendingTheoryPropagation` value, +complete the merge, and only then perform the propagation. + +We now use this workflow for *all* theory solvers, even when a particular +solver does not rely on these invariants. This keeps the core +solver-agnostic and lets us modify solvers without further adjustments. -/ -private def propagateOffsetEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do +inductive PendingTheoryPropagation where + | /-- Nothing to propagate. -/ + none + | /-- Propagate the equality `lhs = rhs`. -/ + eq (lhs rhs : Expr) + | + /-- + Propagate the literal equality `lhs = lit`. + This is needed because some solvers do not internalize literal values. + Remark: we may remove this optimization in the future because it adds complexity + for a small performance gain. + -/ + eqLit (lhs lit : Expr) + | /-- Propagate the disequalities in `ps`. -/ + diseqs (ps : ParentSet) + +/-- +Helper function for combining `ENode.offset?` fields and detecting what needs +to be propagated to the offset constraint module. +-/ +private def checkOffsetEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do match lhsRoot.offset? with | some lhsOffset => if let some rhsOffset := rhsRoot.offset? then - Arith.Offset.processNewEq lhsOffset rhsOffset + return .eq lhsOffset rhsOffset else if isNatNum rhsRoot.self then - Arith.Offset.processNewEqLit lhsOffset rhsRoot.self + return .eqLit lhsOffset rhsRoot.self else -- We have to retrieve the node because other fields have been updated let rhsRoot ← getENode rhsRoot.self setENode rhsRoot.self { rhsRoot with offset? := lhsOffset } + return .none | none => if isNatNum lhsRoot.self then - if let some rhsOffset := rhsRoot.offset? then - Arith.Offset.processNewEqLit rhsOffset lhsRoot.self + if let some rhsOffset := rhsRoot.offset? then + return .eqLit rhsOffset lhsRoot.self + return .none + +def propagateOffset : PendingTheoryPropagation → GoalM Unit + | .eq lhs rhs => Arith.Offset.processNewEq lhs rhs + | .eqLit lhs lit => Arith.Offset.processNewEqLit lhs lit + | _ => return () /-- -Helper function for combining `ENode.cutsat?` fields and propagating equalities -to the cutsat module. -It returns a set of parents that should be traversed for disequality propagation. +Helper function for combining `ENode.cutsat?` fields and detecting what needs +to be propagated to the cutsat module. -/ -private def propagateCutsatEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do +private def checkCutsatEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do match lhsRoot.cutsat? with | some lhsCutsat => if let some rhsCutsat := rhsRoot.cutsat? then - Arith.Cutsat.processNewEq lhsCutsat rhsCutsat - return {} + return .eq lhsCutsat rhsCutsat else if isNum rhsRoot.self then - Arith.Cutsat.processNewEqLit lhsCutsat rhsRoot.self - return {} + return .eqLit lhsCutsat rhsRoot.self else -- We have to retrieve the node because other fields have been updated let rhsRoot ← getENode rhsRoot.self setENode rhsRoot.self { rhsRoot with cutsat? := lhsCutsat } - getParents rhsRoot.self + return .diseqs (← getParents rhsRoot.self) | none => if let some rhsCutsat := rhsRoot.cutsat? then if isNum lhsRoot.self then - Arith.Cutsat.processNewEqLit rhsCutsat lhsRoot.self - return {} + return .eqLit rhsCutsat lhsRoot.self else - getParents lhsRoot.self + return .diseqs (← getParents lhsRoot.self) else - return {} + return .none + +def propagateCutsat : PendingTheoryPropagation → GoalM Unit + | .eq lhs rhs => Arith.Cutsat.processNewEq lhs rhs + | .eqLit lhs lit => Arith.Cutsat.processNewEqLit lhs lit + | .diseqs ps => propagateCutsatDiseqs ps + | .none => return () /-- -Helper function for combining `ENode.ring?` fields and propagating equalities -to the commutative ring module. -It returns a set of parents that should be traversed for disequality propagation. +Helper function for combining `ENode.ring?` fields and detecting what needs to be +progagated to the commutative ring module. -/ -private def propagateCommRingEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do +private def checkCommRingEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do match lhsRoot.ring? with | some lhsRing => if let some rhsRing := rhsRoot.ring? then - Arith.CommRing.processNewEq lhsRing rhsRing - return {} + return .eq lhsRing rhsRing else -- We have to retrieve the node because other fields have been updated let rhsRoot ← getENode rhsRoot.self setENode rhsRoot.self { rhsRoot with ring? := lhsRing } - getParents rhsRoot.self + return .diseqs (← getParents rhsRoot.self) | none => if rhsRoot.ring?.isSome then - getParents lhsRoot.self + return .diseqs (← getParents lhsRoot.self) else - return {} + return .none + +def propagateCommRing : PendingTheoryPropagation → GoalM Unit + | .eq lhs rhs => Arith.CommRing.processNewEq lhs rhs + | .diseqs ps => propagateCommRingDiseqs ps + | _ => return () /-- Tries to apply beta-reductiong using the parent applications of the functions in `fns` with @@ -262,9 +303,9 @@ where } propagateBeta lams₁ fns₁ propagateBeta lams₂ fns₂ - propagateOffsetEq rhsRoot lhsRoot - let parentsToPropagateCutsatDiseqs ← propagateCutsatEq rhsRoot lhsRoot - let parentsToPropagateRingDiseqs ← propagateCommRingEq rhsRoot lhsRoot + let offsetTodo ← checkOffsetEq rhsRoot lhsRoot + let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot + let ringTodo ← checkCommRingEq rhsRoot lhsRoot resetParentsOf lhsRoot.self copyParentsTo parents rhsNode.root unless (← isInconsistent) do @@ -274,8 +315,9 @@ where propagateUp parent for e in toPropagateDown do propagateDown e - propagateCutsatDiseqs parentsToPropagateCutsatDiseqs - propagateCommRingDiseqs parentsToPropagateRingDiseqs + propagateOffset offsetTodo + propagateCutsat cutsatTodo + propagateCommRing ringTodo updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do traverseEqc lhs fun n => setENode n.self { n with root := rootNew } diff --git a/tests/lean/grind/disequality_error.lean b/tests/lean/run/grind_qpartition.lean similarity index 82% rename from tests/lean/grind/disequality_error.lean rename to tests/lean/run/grind_qpartition.lean index e1e842862b..7d4ce218fa 100644 --- a/tests/lean/grind/disequality_error.lean +++ b/tests/lean/run/grind_qpartition.lean @@ -1,5 +1,5 @@ open Array - +set_option grind.warning false reset_grind_attrs% attribute [grind] Vector.getElem_swap_of_ne @@ -15,15 +15,6 @@ theorem qpartition_loop_spec₁ {n} (lt : α → α → Bool) (lo hi : Nat) ∀ k, (h₁ : lo ≤ k) → (h₂ : k < mid) → lt as'[k] as'[mid] := by sorry -/-- -warning: The `grind` tactic is experimental and still under development. Avoid using it in production projects. ---- -error: internal `grind` error, failed to build disequality proof for - (lo + hi) / 2 -and - lo --/ -#guard_msgs in example {n} (lt : α → α → Bool) (lo hi : Nat) (hlo : lo < n := by omega) (hhi : hi < n := by omega) (w : lo ≤ hi := by omega) (as : Vector α n) (mid as')