diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean index 5adc3c0205..5f08683d03 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean @@ -45,5 +45,6 @@ builtin_initialize registerTraceClass `grind.debug.cutsat.internalize builtin_initialize registerTraceClass `grind.debug.cutsat.toInt builtin_initialize registerTraceClass `grind.debug.cutsat.search.cnstrs builtin_initialize registerTraceClass `grind.debug.cutsat.search.reorder +builtin_initialize registerTraceClass `grind.debug.cutsat.elimEq end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean index 788317988f..eee2868ba5 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean @@ -169,6 +169,16 @@ private def updateDiseqs (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Uni c₂.assert if (← inconsistent) then return () +private def updateElimEqs (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do + if (← inconsistent) then return () + assert! x != y + let some c₂ := (← get').elimEqs[y]! | return () + let b := c₂.p.coeff x + if b == 0 then return () + let c₂ := { p := c₂.p.mul a |>.combine (c.p.mul (-b)), h := .subst x c₂ c : EqCnstr } + trace[grind.debug.cutsat.elimEq] "updated: {← getVar y}, {← c₂.pp}" + modify' fun s => { s with elimEqs := s.elimEqs.set y (some c₂) } + private def updateOccsAt (k : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do updateDvdCnstr k x c y updateLowers k x c y @@ -181,6 +191,29 @@ private def updateOccs (k : Int) (x : Var) (c : EqCnstr) : GoalM Unit := do updateOccsAt k x c x for y in ys do updateOccsAt k x c y + updateElimEqs k x c y + +/-- +Similar to `updateOccs`, but does not assume first variable is `p`s "owner". +Recall that when eliminating equalities we do not necessarily eliminate the +maximal variable, but the one with unit coefficient. +Remark: we keep occurrences for equations in `elimEqs` because we want to maintain them +in solved form. Consider the following scenario: +1- Asserted `a + 2*b + 1 = 0`, and eliminated `a` +2- Asserted `b + 1 = 0`, and eliminated `b`. + +At step 2, we want to substitute `b` at `a + 2*b + 1` to ensure `cutsat` knows +`a` is forced to be equal to a constant value. This is relevant for linearizing +nonlinear terms. + +Remark: `x` is the variable that was eliminated using `p`. +-/ +partial def _root_.Int.Linear.Poly.updateOccsForElimEq (p : Poly) (x : Var) : GoalM Unit := do + let rec go (p : Poly) : GoalM Unit := do + let .add _ y p := p | return () + unless x == y do addOcc y x + go p + go p @[export lean_grind_cutsat_assert_eq] def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do @@ -205,11 +238,13 @@ def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do let some (k, x) := c.p.pickVarToElim? | c.throwUnexpected trace[grind.debug.cutsat.subst] ">> {← getVar x}, {← c.pp}" trace[grind.cutsat.assert.store] "{← c.pp}" + trace[grind.debug.cutsat.elimEq] "{← getVar x}, {← c.pp}" modify' fun s => { s with elimEqs := s.elimEqs.set x (some c) elimStack := x :: s.elimStack } updateOccs k x c + c.p.updateOccsForElimEq x if (← inconsistent) then return () -- assert a divisibility constraint IF `|k| != 1` if k.natAbs != 1 then