fix: equality propagation and simplification in the comm ring procedure (#8137)

This PR improves equality propagation (also known as theory combination)
and polynomial simplification for rings that do not implement the
`NoZeroNatDivisors` class. With these fixes, `grind` can now solve:
```lean
example [CommRing α] (a b c : α) (f : α → Nat)
  : a + b + c = 3 →
    a^2 + b^2 + c^2 = 5 →
    a^3 + b^3 + c^3 = 7 →
    f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by
  grind +ring
```
This example uses the commutative ring procedure, the linear integer
arithmetic solver, and congruence closure.
For rings that implement `NoZeroNatDivisors`, a polynomial is now also
divided by the greatest common divisor (gcd) of its coefficients when it
is inserted into the basis.
This commit is contained in:
Leonardo de Moura 2025-04-27 17:55:18 -07:00 committed by GitHub
parent b77e9edd44
commit 2ba021ecc2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 133 additions and 14 deletions

View file

@ -46,12 +46,13 @@ Remark: if the current ring does not satisfy the property
then the leading coefficient of the equation must also divide `k`
-/
def _root_.Lean.Grind.CommRing.Mon.findSimp? (k : Int) (m : Mon) : RingM (Option EqCnstr) := do
let checkCoeff ← checkCoeffDvd
let noZeroDiv ← noZeroDivisors
let rec go : Mon → RingM (Option EqCnstr)
| .unit => return none
| .mult pw m' => do
for c in (← getRing).varToBasis[pw.x]! do
if noZeroDiv || (c.p.lc k) then
if !checkCoeff || noZeroDiv || (c.p.lc k) then
if c.p.divides m then
return some c
go m'
@ -156,7 +157,6 @@ def EqCnstr.simplifyAndCheck (c : EqCnstr) : RingM (Option EqCnstr) := do
def addToBasisCore (c : EqCnstr) : RingM Unit := do
let .add _ m _ := c.p | return ()
let .mult pw _ := m | return ()
trace_goal[grind.ring.assert.basis] "{← c.denoteExpr}"
modifyRing fun s => { s with
varToBasis := s.varToBasis.modify pw.x (c :: ·)
recheck := true
@ -207,6 +207,9 @@ if the ring has a nonzero characteristic `p` and `gcd k p = 1`, then
`k` has an inverse.
It also handles the easy case where `k` is `-1`.
Remark: if the ring implements the class `NoZeroNatDivisors`, then
the coefficients are divided by the gcd of all coefficients.
-/
def EqCnstr.toMonic (c : EqCnstr) : RingM EqCnstr := do
let k := c.p.lc
@ -218,17 +221,20 @@ def EqCnstr.toMonic (c : EqCnstr) : RingM EqCnstr := do
-- `α*k = 1 (mod p)`
let α := if α < 0 then α % p else α
return { c with p := c.p.mulConstC α p, h := .mul α c }
else
return c
else if k == -1 then
if (← noZeroDivisors) then
let g : Int := c.p.gcdCoeffs
if g != 1 then
let g := if k < 0 then -g else g
return { c with p := c.p.divConst g, h := .div g c }
if k < 0 then
return { c with p := c.p.mulConst (-1), h := .mul (-1) c }
else
return c
return c
def EqCnstr.addToBasisAfterSimp (c : EqCnstr) : RingM Unit := do
let c ← c.toMonic
c.simplifyBasis
c.superposeWith
trace_goal[grind.ring.assert.basis] "{← c.denoteExpr}"
addToBasisCore c
def EqCnstr.addToBasis (c : EqCnstr) : RingM Unit := do
@ -251,8 +257,10 @@ def DiseqCnstr.checkConstant (c : DiseqCnstr) : RingM Bool := do
trace_goal[grind.ring.assert.trivial] "{← c.denoteExpr}"
return true
def DiseqCnstr.simplify (c : DiseqCnstr) : RingM DiseqCnstr := do
return { c with d := (← c.d.simplify) }
def DiseqCnstr.simplify (c : DiseqCnstr) : RingM DiseqCnstr :=
withCheckCoeffDvd do
-- We must enable `checkCoeffDvd := true`. See comments at `PolyDerivation`.
return { c with d := (← c.d.simplify) }
def saveDiseq (c : DiseqCnstr) : RingM Unit := do
trace_goal[grind.ring.assert.store] "{← c.denoteExpr}"
@ -316,9 +324,15 @@ private def propagateEqs : RingM Unit := do
TODO: optimize
-/
let mut map : PropagateEqMap := {}
for a in (← getRing).vars do
if (← checkMaxSteps) then return ()
let some ra ← toRingExpr? a | unreachable!
map ← process map a ra
for (a, ra) in (← getRing).denote do
if (← checkMaxSteps) then return ()
let a := a.expr
map ← process map a.expr ra
where
process (map : PropagateEqMap) (a : Expr) (ra : RingExpr) : RingM PropagateEqMap := do
let d : PolyDerivation := .input (← ra.toPolyM)
let d ← d.simplify
let k := d.getMultiplier
@ -329,10 +343,17 @@ private def propagateEqs : RingM Unit := do
let p ← (ra.sub rb).toPolyM
let d : PolyDerivation := .input p
let d ← d.simplify
if d.getMultiplier != 1 then
unless (← noZeroDivisors) do
-- Given the multipiler `k' = d.getMultiplier`, we have that `k*(a - b) = 0`,
-- but we cannot eliminate the `k` because we don't have `noZeroDivisors`.
trace_goal[grind.ring.impEq] "skip: {← mkEq a b}, k: {k}, noZeroDivisors: false"
return map.insert (k, d.p) (a, ra)
trace_goal[grind.ring.impEq] "{← mkEq a b}, {k}, {← p.denoteExpr}"
propagateEq a b ra rb d
return map
else
map := map.insert (k, d.p) (a, ra)
return map.insert (k, d.p) (a, ra)
def checkRing : RingM Bool := do
unless (← needCheck) do return false

View file

@ -195,4 +195,20 @@ def Poly.checkNoUnitMon : Poly → Bool
| .add _ .unit _ => false
| .add _ _ p => p.checkNoUnitMon
def Poly.gcdCoeffs : Poly → Nat
| .num k => k.natAbs
| .add k _ p => go p k.natAbs
where
go (p : Poly) (acc : Nat) : Nat :=
if acc == 1 then
acc
else match p with
| .num k => Nat.gcd acc k.natAbs
| .add k _ p => go p (Nat.gcd acc k.natAbs)
def Poly.divConst (p : Poly) (a : Int) : Poly :=
match p with
| .num k => .num (k / a)
| .add k m p => .add (k / a) m (divConst p a)
end Lean.Grind.CommRing

View file

@ -21,14 +21,29 @@ def checkMaxSteps : GoalM Bool := do
def incSteps : GoalM Unit := do
modify' fun s => { s with steps := s.steps + 1 }
structure RingM.Context where
ringId : Nat
/--
If `checkCoeffDvd` is `true`, then when using a polynomial `k*m - p`
to simplify `.. + k'*m*m_2 + ...`, the substitution is performed IF
- `k` divides `k'`, OR
- Ring implements `NoZeroNatDivisors`.
We need this check when simplifying disequalities. In this case, if we perform
the simplification anyway, we may end up with a proof that `k * q = 0`, but
we cannot deduce `q = 0` since the ring does not implement `NoZeroNatDivisors`
See comment at `PolyDerivation`.
-/
checkCoeffDvd : Bool := false
/-- We don't want to keep carrying the `RingId` around. -/
abbrev RingM := ReaderT Nat GoalM
abbrev RingM := ReaderT RingM.Context GoalM
abbrev RingM.run (ringId : Nat) (x : RingM α) : GoalM α :=
x ringId
x { ringId }
abbrev getRingId : RingM Nat :=
read
return (← read).ringId
def getRing : RingM Ring := do
let s ← get'
@ -42,6 +57,12 @@ def getRing : RingM Ring := do
let ringId ← getRingId
modify' fun s => { s with rings := s.rings.modify ringId f }
abbrev withCheckCoeffDvd (x : RingM α) : RingM α :=
withReader (fun ctx => { ctx with checkCoeffDvd := true }) x
def checkCoeffDvd : RingM Bool :=
return (← read).checkCoeffDvd
def getTermRingId? (e : Expr) : GoalM (Option Nat) := do
return (← get').exprToRingId.find? { expr := e }

View file

@ -65,3 +65,64 @@ example [CommRing α] (a b c : α)
a^3 + b^3 + c^3 = 7 →
a^4 + b^4 + c^4 = 9 := by
grind +ring
/--
info: [grind.ring.assert.basis] a + b + c + -3 = 0
[grind.ring.assert.basis] 2 * b ^ 2 + 2 * (b * c) + 2 * c ^ 2 + -6 * b + -6 * c + 4 = 0
[grind.ring.assert.basis] 6 * c ^ 3 + -18 * c ^ 2 + 12 * c + 4 = 0
-/
#guard_msgs (info) in
example [CommRing α] (a b c : α)
: a + b + c = 3 →
a^2 + b^2 + c^2 = 5 →
a^3 + b^3 + c^3 = 7 →
a^4 + b^4 = 9 - c^4 := by
set_option trace.grind.ring.assert.basis true in
grind +ring
/--
info: [grind.ring.assert.basis] a + b + c + -3 = 0
[grind.ring.assert.basis] b ^ 2 + b * c + c ^ 2 + -3 * b + -3 * c + 2 = 0
[grind.ring.assert.basis] 3 * c ^ 3 + -9 * c ^ 2 + 6 * c + 2 = 0
-/
#guard_msgs (info) in
example [CommRing α] [NoZeroNatDivisors α] (a b c : α)
: a + b + c = 3 →
a^2 + b^2 + c^2 = 5 →
a^3 + b^3 + c^3 = 7 →
a^4 + b^4 = 9 - c^4 := by
set_option trace.grind.ring.assert.basis true in
grind +ring
example [CommRing α] (a b : α) (f : α → Nat) : a - b = 0 → f a = f b := by
grind +ring
example (a b : BitVec 8) (f : BitVec 8 → Nat) : a - b = 0 → f a = f b := by
grind +ring
example (a b c : BitVec 8) (f : BitVec 8 → Nat) : c = 255 → - a + b - 1 = c → f a = f b := by
grind +ring
example (a b c : BitVec 8) (f : BitVec 8 → Nat) : c = 255 → - a + b - 1 = c → f (2*a) = f (b + a) := by
grind +ring
/-- info: [grind.ring.impEq] skip: b = a, k: 2, noZeroDivisors: false -/
#guard_msgs (info) in
example (a b c : BitVec 8) (f : BitVec 8 → Nat) : 2*a = 1 → 2*b = 1 → f (a) = f (b) := by
set_option trace.grind.ring.impEq true in
fail_if_success grind +ring
sorry
example (a b c : Int) (f : Int → Nat)
: a + b + c = 3 →
a^2 + b^2 + c^2 = 5 →
a^3 + b^3 + c^3 = 7 →
f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by
grind +ring
example [CommRing α] (a b c : α) (f : α → Nat)
: a + b + c = 3 →
a^2 + b^2 + c^2 = 5 →
a^3 + b^3 + c^3 = 7 →
f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by
grind +ring