feat: basic equality propagation for IntModule in grind (#11677)

This PR adds basic support for equality propagation in `grind linarith`
for the `IntModule` case. This covers only the basic case. See note in
the code.
We remark this feature is irrelevant for `CommRing` since `grind ring`
already has much better support for equality propagation.
This commit is contained in:
Leonardo de Moura 2025-12-14 23:40:11 +01:00 committed by GitHub
parent 429e09cd82
commit 62b900e8ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 56 additions and 0 deletions

View file

@ -554,4 +554,13 @@ theorem eq_eq_subst {α} [IntModule α] (ctx : Context α) (x : Var) (p₁ p₂
: eq_eq_subst_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 → p₃.denote' ctx = 0 := by
simp [eq_eq_subst_cert]; intro _ h₁ h₂; subst p₃; simp [h₁, h₂]
def imp_eq_cert (p : Poly) (x y : Var) : Bool :=
p == .add 1 x (.add (-1) y .nil)
theorem imp_eq {α} [IntModule α] (ctx : Context α) (p : Poly) (x y : Var)
: imp_eq_cert p x y → p.denote' ctx = 0 → x.denote ctx = y.denote ctx := by
simp [imp_eq_cert]; intro; subst p; simp [Poly.denote]
rw [neg_zsmul, ← sub_eq_add_neg, one_zsmul, sub_eq_zero_iff]
simp
end Lean.Grind.Linarith

View file

@ -486,6 +486,16 @@ def setInconsistent (h : UnsatProof) : LinearM Unit := do
let h ← h.toExprProof
closeGoal h
def propagateImpEq (c : EqCnstr) : LinearM Unit := do
let .add 1 x (.add (-1) y .nil) := c.p | unreachable!
let a ← getVar x
let b ← getVar y
let h ← withProofContext do
let h ← mkIntModThmPrefix ``Grind.Linarith.imp_eq
return mkApp5 h (← mkPolyDecl c.p) (← mkVarDecl x) (← mkVarDecl y) eagerReflBoolTrue (← c.toExprProof)
let h := mkExpectedPropHint h (← mkEq a b)
pushEq a b h
/-!
A linarith proof may depend on decision variables.
We collect them and perform non chronological backtracking.

View file

@ -198,6 +198,21 @@ private def updateOccs (a : Nat) (x : Var) (c : EqCnstr) : LinearM Unit := do
for y in ys do
updateOccsAt a x c y
private def isImpliedEq (c : EqCnstr) : LinearM Bool := do
match c.p with
| .add (-1) x (.add 1 y .nil)
| .add 1 x (.add (-1) y .nil) =>
if (← isEqv (← getVar x) (← getVar y)) then return false
return true
| _ => return false
private def ensureLeadCoeffPos (c : EqCnstr) : LinearM EqCnstr := do
let .add k _ _ := c.p | return c
if k < 0 then
return { p := c.p.mul (-1), h := .neg c }
else
return c
private def EqCnstr.assert (c : EqCnstr) : LinearM Unit := do
trace[grind.linarith.assert] "{← c.denoteExpr}"
let c ← c.applySubsts
@ -207,6 +222,17 @@ private def EqCnstr.assert (c : EqCnstr) : LinearM Unit := do
let (a, x, c) ← c.norm
trace[grind.debug.linarith.subst] ">> {← getVar x}, {← c.denoteExpr}"
trace[grind.linarith.assert.store] "{← c.denoteExpr}"
/-
**Note**:
We currently only catch equalities of the form `x + -1*y = 0`
This is sufficient for catching trivial cases, but to catch all implied equalities
we need to keep a mapping from `(Poly, Int)` to `Var`. The mapping contains an entry `(p, k) ↦ x`
if `x` is an eliminated variable and there is a constraint that implies `k*x = p`.
We need this mapping to catch `k*x = p` and `k*y = p`
-/
unless (← getStruct).caseSplits do
if (← isImpliedEq c) then
propagateImpEq (← ensureLeadCoeffPos c)
modifyStruct fun s => { s with
elimEqs := s.elimEqs.set x (some c)
elimStack := x :: s.elimStack

View file

@ -0,0 +1,11 @@
example {W : Type} [Lean.Grind.IntModule W] (f : W → Nat)
(_ : ∀ (a : Int) (x : W), f (a • x) = a.natAbs * f x)
(_ : a ≠ 1) (_ : a ≠ -1) (x : W) (_ : f x = 1) :
¬ x - a • x = 0 := by
grind
example {W : Type} [Lean.Grind.IntModule W] (f : W → Nat)
(_ : ∀ (a : Int) (x : W), f (a • x) = a.natAbs * f x)
(_ : a ≠ 1) (_ : a ≠ -1) (x y : W) (_ : f x = 1) :
y ≠ x → ¬ x - a • x = 0 := by
grind