From 62b900e8ef88d1876af00da0a87f7375aed2e83f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 14 Dec 2025 23:40:11 +0100 Subject: [PATCH] feat: basic equality propagation for `IntModule` in `grind` (#11677) This PR adds basic support for equality propagation in `grind linarith` for the `IntModule` case. This covers only the basic case. See note in the code. We remark this feature is irrelevant for `CommRing` since `grind ring` already has much better support for equality propagation. --- src/Init/Grind/Ordered/Linarith.lean | 9 +++++++ .../Meta/Tactic/Grind/Arith/Linear/Proof.lean | 10 +++++++ .../Grind/Arith/Linear/PropagateEq.lean | 26 +++++++++++++++++++ tests/lean/run/grind_intmodule_eq_prop.lean | 11 ++++++++ 4 files changed, 56 insertions(+) create mode 100644 tests/lean/run/grind_intmodule_eq_prop.lean diff --git a/src/Init/Grind/Ordered/Linarith.lean b/src/Init/Grind/Ordered/Linarith.lean index d064ad5d70..dab4a5bba8 100644 --- a/src/Init/Grind/Ordered/Linarith.lean +++ b/src/Init/Grind/Ordered/Linarith.lean @@ -554,4 +554,13 @@ theorem eq_eq_subst {α} [IntModule α] (ctx : Context α) (x : Var) (p₁ p₂ : eq_eq_subst_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 → p₃.denote' ctx = 0 := by simp [eq_eq_subst_cert]; intro _ h₁ h₂; subst p₃; simp [h₁, h₂] +def imp_eq_cert (p : Poly) (x y : Var) : Bool := + p == .add 1 x (.add (-1) y .nil) + +theorem imp_eq {α} [IntModule α] (ctx : Context α) (p : Poly) (x y : Var) + : imp_eq_cert p x y → p.denote' ctx = 0 → x.denote ctx = y.denote ctx := by + simp [imp_eq_cert]; intro; subst p; simp [Poly.denote] + rw [neg_zsmul, ← sub_eq_add_neg, one_zsmul, sub_eq_zero_iff] + simp + end Lean.Grind.Linarith diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean index 72ef45a833..b563fb70b4 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean @@ -486,6 +486,16 @@ def setInconsistent (h : UnsatProof) : LinearM Unit := do let h ← h.toExprProof closeGoal h +def propagateImpEq (c : EqCnstr) : LinearM Unit := do + let .add 1 x (.add (-1) y .nil) := c.p | unreachable! + let a ← getVar x + let b ← getVar y + let h ← withProofContext do + let h ← mkIntModThmPrefix ``Grind.Linarith.imp_eq + return mkApp5 h (← mkPolyDecl c.p) (← mkVarDecl x) (← mkVarDecl y) eagerReflBoolTrue (← c.toExprProof) + let h := mkExpectedPropHint h (← mkEq a b) + pushEq a b h + /-! A linarith proof may depend on decision variables. We collect them and perform non chronological backtracking. diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean index 9c6ce7c139..5a7486c0a2 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean @@ -198,6 +198,21 @@ private def updateOccs (a : Nat) (x : Var) (c : EqCnstr) : LinearM Unit := do for y in ys do updateOccsAt a x c y +private def isImpliedEq (c : EqCnstr) : LinearM Bool := do + match c.p with + | .add (-1) x (.add 1 y .nil) + | .add 1 x (.add (-1) y .nil) => + if (← isEqv (← getVar x) (← getVar y)) then return false + return true + | _ => return false + +private def ensureLeadCoeffPos (c : EqCnstr) : LinearM EqCnstr := do + let .add k _ _ := c.p | return c + if k < 0 then + return { p := c.p.mul (-1), h := .neg c } + else + return c + private def EqCnstr.assert (c : EqCnstr) : LinearM Unit := do trace[grind.linarith.assert] "{← c.denoteExpr}" let c ← c.applySubsts @@ -207,6 +222,17 @@ private def EqCnstr.assert (c : EqCnstr) : LinearM Unit := do let (a, x, c) ← c.norm trace[grind.debug.linarith.subst] ">> {← getVar x}, {← c.denoteExpr}" trace[grind.linarith.assert.store] "{← c.denoteExpr}" + /- + **Note**: + We currently only catch equalities of the form `x + -1*y = 0` + This is sufficient for catching trivial cases, but to catch all implied equalities + we need to keep a mapping from `(Poly, Int)` to `Var`. The mapping contains an entry `(p, k) ↦ x` + if `x` is an eliminated variable and there is a constraint that implies `k*x = p`. + We need this mapping to catch `k*x = p` and `k*y = p` + -/ + unless (← getStruct).caseSplits do + if (← isImpliedEq c) then + propagateImpEq (← ensureLeadCoeffPos c) modifyStruct fun s => { s with elimEqs := s.elimEqs.set x (some c) elimStack := x :: s.elimStack diff --git a/tests/lean/run/grind_intmodule_eq_prop.lean b/tests/lean/run/grind_intmodule_eq_prop.lean new file mode 100644 index 0000000000..a72828ef33 --- /dev/null +++ b/tests/lean/run/grind_intmodule_eq_prop.lean @@ -0,0 +1,11 @@ +example {W : Type} [Lean.Grind.IntModule W] (f : W → Nat) + (_ : ∀ (a : Int) (x : W), f (a • x) = a.natAbs * f x) + (_ : a ≠ 1) (_ : a ≠ -1) (x : W) (_ : f x = 1) : + ¬ x - a • x = 0 := by + grind + +example {W : Type} [Lean.Grind.IntModule W] (f : W → Nat) + (_ : ∀ (a : Int) (x : W), f (a • x) = a.natAbs * f x) + (_ : a ≠ 1) (_ : a ≠ -1) (x y : W) (_ : f x = 1) : + y ≠ x → ¬ x - a • x = 0 := by + grind