From 2ba021ecc21636989aa82aa82c99f9443f3e774a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 27 Apr 2025 17:55:18 -0700 Subject: [PATCH] fix: equality propagation and simplification in the comm ring procedure (#8137) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR improves equality propagation (also known as theory combination) and polynomial simplification for rings that do not implement the `NoZeroNatDivisors` class. With these fixes, `grind` can now solve: ```lean example [CommRing α] (a b c : α) (f : α → Nat) : a + b + c = 3 → a^2 + b^2 + c^2 = 5 → a^3 + b^3 + c^3 = 7 → f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by grind +ring ``` This example uses the commutative ring procedure, the linear integer arithmetic solver, and congruence closure. For rings that implement `NoZeroNatDivisors`, a polynomial is now also divided by the greatest common divisor (gcd) of its coefficients when it is inserted into the basis. --- .../Tactic/Grind/Arith/CommRing/EqCnstr.lean | 43 +++++++++---- .../Tactic/Grind/Arith/CommRing/Poly.lean | 16 +++++ .../Tactic/Grind/Arith/CommRing/Util.lean | 27 +++++++- tests/lean/run/grind_ring_2.lean | 61 +++++++++++++++++++ 4 files changed, 133 insertions(+), 14 deletions(-) diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean index a59e3a467a..338fe04f67 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean @@ -46,12 +46,13 @@ Remark: if the current ring does not satisfy the property then the leading coefficient of the equation must also divide `k` -/ def _root_.Lean.Grind.CommRing.Mon.findSimp? (k : Int) (m : Mon) : RingM (Option EqCnstr) := do + let checkCoeff ← checkCoeffDvd let noZeroDiv ← noZeroDivisors let rec go : Mon → RingM (Option EqCnstr) | .unit => return none | .mult pw m' => do for c in (← getRing).varToBasis[pw.x]! do - if noZeroDiv || (c.p.lc ∣ k) then + if !checkCoeff || noZeroDiv || (c.p.lc ∣ k) then if c.p.divides m then return some c go m' @@ -156,7 +157,6 @@ def EqCnstr.simplifyAndCheck (c : EqCnstr) : RingM (Option EqCnstr) := do def addToBasisCore (c : EqCnstr) : RingM Unit := do let .add _ m _ := c.p | return () let .mult pw _ := m | return () - trace_goal[grind.ring.assert.basis] "{← c.denoteExpr}" modifyRing fun s => { s with varToBasis := s.varToBasis.modify pw.x (c :: ·) recheck := true @@ -207,6 +207,9 @@ if the ring has a nonzero characteristic `p` and `gcd k p = 1`, then `k` has an inverse. It also handles the easy case where `k` is `-1`. + +Remark: if the ring implements the class `NoZeroNatDivisors`, then +the coefficients are divided by the gcd of all coefficients. -/ def EqCnstr.toMonic (c : EqCnstr) : RingM EqCnstr := do let k := c.p.lc @@ -218,17 +221,20 @@ def EqCnstr.toMonic (c : EqCnstr) : RingM EqCnstr := do -- `α*k = 1 (mod p)` let α := if α < 0 then α % p else α return { c with p := c.p.mulConstC α p, h := .mul α c } - else - return c - else if k == -1 then + if (← noZeroDivisors) then + let g : Int := c.p.gcdCoeffs + if g != 1 then + let g := if k < 0 then -g else g + return { c with p := c.p.divConst g, h := .div g c } + if k < 0 then return { c with p := c.p.mulConst (-1), h := .mul (-1) c } - else - return c + return c def EqCnstr.addToBasisAfterSimp (c : EqCnstr) : RingM Unit := do let c ← c.toMonic c.simplifyBasis c.superposeWith + trace_goal[grind.ring.assert.basis] "{← c.denoteExpr}" addToBasisCore c def EqCnstr.addToBasis (c : EqCnstr) : RingM Unit := do @@ -251,8 +257,10 @@ def DiseqCnstr.checkConstant (c : DiseqCnstr) : RingM Bool := do trace_goal[grind.ring.assert.trivial] "{← c.denoteExpr}" return true -def DiseqCnstr.simplify (c : DiseqCnstr) : RingM DiseqCnstr := do - return { c with d := (← c.d.simplify) } +def DiseqCnstr.simplify (c : DiseqCnstr) : RingM DiseqCnstr := + withCheckCoeffDvd do + -- We must enable `checkCoeffDvd := true`. See comments at `PolyDerivation`. + return { c with d := (← c.d.simplify) } def saveDiseq (c : DiseqCnstr) : RingM Unit := do trace_goal[grind.ring.assert.store] "{← c.denoteExpr}" @@ -316,9 +324,15 @@ private def propagateEqs : RingM Unit := do TODO: optimize -/ let mut map : PropagateEqMap := {} + for a in (← getRing).vars do + if (← checkMaxSteps) then return () + let some ra ← toRingExpr? a | unreachable! + map ← process map a ra for (a, ra) in (← getRing).denote do if (← checkMaxSteps) then return () - let a := a.expr + map ← process map a.expr ra +where + process (map : PropagateEqMap) (a : Expr) (ra : RingExpr) : RingM PropagateEqMap := do let d : PolyDerivation := .input (← ra.toPolyM) let d ← d.simplify let k := d.getMultiplier @@ -329,10 +343,17 @@ private def propagateEqs : RingM Unit := do let p ← (ra.sub rb).toPolyM let d : PolyDerivation := .input p let d ← d.simplify + if d.getMultiplier != 1 then + unless (← noZeroDivisors) do + -- Given the multipiler `k' = d.getMultiplier`, we have that `k*(a - b) = 0`, + -- but we cannot eliminate the `k` because we don't have `noZeroDivisors`. + trace_goal[grind.ring.impEq] "skip: {← mkEq a b}, k: {k}, noZeroDivisors: false" + return map.insert (k, d.p) (a, ra) trace_goal[grind.ring.impEq] "{← mkEq a b}, {k}, {← p.denoteExpr}" propagateEq a b ra rb d + return map else - map := map.insert (k, d.p) (a, ra) + return map.insert (k, d.p) (a, ra) def checkRing : RingM Bool := do unless (← needCheck) do return false diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean index b6c5f2a146..418349de6c 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean @@ -195,4 +195,20 @@ def Poly.checkNoUnitMon : Poly → Bool | .add _ .unit _ => false | .add _ _ p => p.checkNoUnitMon +def Poly.gcdCoeffs : Poly → Nat + | .num k => k.natAbs + | .add k _ p => go p k.natAbs +where + go (p : Poly) (acc : Nat) : Nat := + if acc == 1 then + acc + else match p with + | .num k => Nat.gcd acc k.natAbs + | .add k _ p => go p (Nat.gcd acc k.natAbs) + +def Poly.divConst (p : Poly) (a : Int) : Poly := + match p with + | .num k => .num (k / a) + | .add k m p => .add (k / a) m (divConst p a) + end Lean.Grind.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean index 5ef6c2b5f9..07474bcd1e 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean @@ -21,14 +21,29 @@ def checkMaxSteps : GoalM Bool := do def incSteps : GoalM Unit := do modify' fun s => { s with steps := s.steps + 1 } +structure RingM.Context where + ringId : Nat + /-- + If `checkCoeffDvd` is `true`, then when using a polynomial `k*m - p` + to simplify `.. + k'*m*m_2 + ...`, the substitution is performed IF + - `k` divides `k'`, OR + - Ring implements `NoZeroNatDivisors`. + + We need this check when simplifying disequalities. In this case, if we perform + the simplification anyway, we may end up with a proof that `k * q = 0`, but + we cannot deduce `q = 0` since the ring does not implement `NoZeroNatDivisors` + See comment at `PolyDerivation`. + -/ + checkCoeffDvd : Bool := false + /-- We don't want to keep carrying the `RingId` around. -/ -abbrev RingM := ReaderT Nat GoalM +abbrev RingM := ReaderT RingM.Context GoalM abbrev RingM.run (ringId : Nat) (x : RingM α) : GoalM α := - x ringId + x { ringId } abbrev getRingId : RingM Nat := - read + return (← read).ringId def getRing : RingM Ring := do let s ← get' @@ -42,6 +57,12 @@ def getRing : RingM Ring := do let ringId ← getRingId modify' fun s => { s with rings := s.rings.modify ringId f } +abbrev withCheckCoeffDvd (x : RingM α) : RingM α := + withReader (fun ctx => { ctx with checkCoeffDvd := true }) x + +def checkCoeffDvd : RingM Bool := + return (← read).checkCoeffDvd + def getTermRingId? (e : Expr) : GoalM (Option Nat) := do return (← get').exprToRingId.find? { expr := e } diff --git a/tests/lean/run/grind_ring_2.lean b/tests/lean/run/grind_ring_2.lean index e3232e86c5..7805344b72 100644 --- a/tests/lean/run/grind_ring_2.lean +++ b/tests/lean/run/grind_ring_2.lean @@ -65,3 +65,64 @@ example [CommRing α] (a b c : α) a^3 + b^3 + c^3 = 7 → a^4 + b^4 + c^4 = 9 := by grind +ring + +/-- +info: [grind.ring.assert.basis] a + b + c + -3 = 0 +[grind.ring.assert.basis] 2 * b ^ 2 + 2 * (b * c) + 2 * c ^ 2 + -6 * b + -6 * c + 4 = 0 +[grind.ring.assert.basis] 6 * c ^ 3 + -18 * c ^ 2 + 12 * c + 4 = 0 +-/ +#guard_msgs (info) in +example [CommRing α] (a b c : α) + : a + b + c = 3 → + a^2 + b^2 + c^2 = 5 → + a^3 + b^3 + c^3 = 7 → + a^4 + b^4 = 9 - c^4 := by + set_option trace.grind.ring.assert.basis true in + grind +ring + +/-- +info: [grind.ring.assert.basis] a + b + c + -3 = 0 +[grind.ring.assert.basis] b ^ 2 + b * c + c ^ 2 + -3 * b + -3 * c + 2 = 0 +[grind.ring.assert.basis] 3 * c ^ 3 + -9 * c ^ 2 + 6 * c + 2 = 0 +-/ +#guard_msgs (info) in +example [CommRing α] [NoZeroNatDivisors α] (a b c : α) + : a + b + c = 3 → + a^2 + b^2 + c^2 = 5 → + a^3 + b^3 + c^3 = 7 → + a^4 + b^4 = 9 - c^4 := by + set_option trace.grind.ring.assert.basis true in + grind +ring + +example [CommRing α] (a b : α) (f : α → Nat) : a - b = 0 → f a = f b := by + grind +ring + +example (a b : BitVec 8) (f : BitVec 8 → Nat) : a - b = 0 → f a = f b := by + grind +ring + +example (a b c : BitVec 8) (f : BitVec 8 → Nat) : c = 255 → - a + b - 1 = c → f a = f b := by + grind +ring + +example (a b c : BitVec 8) (f : BitVec 8 → Nat) : c = 255 → - a + b - 1 = c → f (2*a) = f (b + a) := by + grind +ring + +/-- info: [grind.ring.impEq] skip: b = a, k: 2, noZeroDivisors: false -/ +#guard_msgs (info) in +example (a b c : BitVec 8) (f : BitVec 8 → Nat) : 2*a = 1 → 2*b = 1 → f (a) = f (b) := by + set_option trace.grind.ring.impEq true in + fail_if_success grind +ring + sorry + +example (a b c : Int) (f : Int → Nat) + : a + b + c = 3 → + a^2 + b^2 + c^2 = 5 → + a^3 + b^3 + c^3 = 7 → + f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by + grind +ring + +example [CommRing α] (a b c : α) (f : α → Nat) + : a + b + c = 3 → + a^2 + b^2 + c^2 = 5 → + a^3 + b^3 + c^3 = 7 → + f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by + grind +ring