feat: linear integer inequality normalization using gcd of coefficients (#7030)

This PR adds completes the linear integer inequality normalizer for
`grind`. The missing normalization step replaces a linear inequality of
the form `a_1*x_1 + ... + a_n*x_n + b <= 0` with `a_1/k * x_1 + ... +
a_n/k * x_n + ceil(b/k) <= 0` where `k = gcd(a_1, ..., a_n)`.
`ceil(b/k)` is implemented using the helper `cdiv b k`.
This commit is contained in:
Leonardo de Moura 2025-02-10 19:45:25 -08:00 committed by GitHub
parent e7fa5891ea
commit befee896b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 167 additions and 20 deletions

View file

@ -98,8 +98,53 @@ def PolyCnstr.denote (ctx : Context) : PolyCnstr → Prop
| .eq p => p.denote ctx = 0
| .le p => p.denote ctx ≤ 0
def cdiv (a b : Int) : Int :=
-((-a)/b)
def cmod (a b : Int) : Int :=
-((-a)%b)
theorem cdiv_add_cmod (a b : Int) : b*(cdiv a b) + cmod a b = a := by
unfold cdiv cmod
have := Int.ediv_add_emod (-a) b
have := congrArg (Neg.neg) this
simp at this
conv => rhs; rw[← this]
rw [Int.neg_add, ←Int.neg_mul, Int.neg_mul_comm]
theorem cmod_gt_of_pos (a : Int) {b : Int} (h : 0 < b) : cmod a b > -b :=
Int.neg_lt_neg (Int.emod_lt_of_pos (-a) h)
theorem cmod_nonpos (a : Int) {b : Int} (h : b ≠ 0) : cmod a b ≤ 0 := by
have := Int.neg_le_neg (Int.emod_nonneg (-a) h)
simp at this
assumption
theorem cmod_eq_zero_iff_emod_eq_zero (a b : Int) : cmod a b = 0 ↔ a%b = 0 := by
unfold cmod
have := @Int.emod_eq_emod_iff_emod_sub_eq_zero b b a
simp at this
simp [Int.neg_emod, ← this, Eq.comm]
theorem cdiv_eq_div_of_divides {a b : Int} (h : (a/b)*b = a) : a/b = cdiv a b := by
have hz : a % b = 0 := by
have := Int.ediv_add_emod a b
conv at this => rhs; rw [← Int.add_zero a]
rw [Int.mul_comm, h] at this
exact Int.add_left_cancel this
have hcz : cmod a b = 0 := cmod_eq_zero_iff_emod_eq_zero a b |>.mpr hz
have : (cdiv a b)*b = a := by
have := cdiv_add_cmod a b
simp [hcz] at this
rw [Int.mul_comm] at this
assumption
have : (a/b)*b = (cdiv a b)*b := Eq.trans h this.symm
by_cases h : b = 0
next => simp[cdiv, h]
next => rw [Int.mul_eq_mul_right_iff h] at this; assumption
def Poly.div (k : Int) : Poly → Poly
| .num k' => .num (k'/k)
| .num k' => .num (cdiv k' k)
| .add k' x p => .add (k'/k) x (div k p)
def Poly.divAll (k : Int) : Poly → Bool
@ -119,8 +164,14 @@ def PolyCnstr.norm : PolyCnstr → PolyCnstr
| .le p => .le p.norm
def PolyCnstr.divAll (k : Int) : PolyCnstr → Bool
| .eq p => p.divAll k
| .le p => p.divAll k
| .eq p | .le p => p.divAll k
def PolyCnstr.divCoeffs (k : Int) : PolyCnstr → Bool
| .eq p | .le p => p.divCoeffs k
def PolyCnstr.isLe : PolyCnstr → Bool
| .eq _ => false
| .le _ => true
def PolyCnstr.div (k : Int) : PolyCnstr → PolyCnstr
| .eq p => .eq <| p.div k
@ -131,6 +182,10 @@ inductive ExprCnstr where
| le (p₁ p₂ : Expr)
deriving Inhabited, BEq
def ExprCnstr.isLe : ExprCnstr → Bool
| .eq .. => false
| .le .. => true
def ExprCnstr.denote (ctx : Context) : ExprCnstr → Prop
| .eq e₁ e₂ => e₁.denote ctx = e₂.denote ctx
| .le e₁ e₂ => e₁.denote ctx ≤ e₂.denote ctx
@ -177,7 +232,7 @@ attribute [local simp] Poly.div Poly.divAll PolyCnstr.denote
theorem Poly.denote_div_eq_of_divAll (ctx : Context) (p : Poly) (k : Int) : p.divAll k → (p.div k).denote ctx * k = p.denote ctx := by
induction p with
| num _ => simp
| num _ => simp; intro h; rw [← cdiv_eq_div_of_divides h]; assumption
| add k' v p ih =>
simp; intro h₁ h₂
have ih := ih h₂
@ -187,9 +242,9 @@ theorem Poly.denote_div_eq_of_divAll (ctx : Context) (p : Poly) (k : Int) : p.di
attribute [local simp] Poly.divCoeffs Poly.getConst
theorem Poly.denote_div_eq_of_divCoeffs (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → (p.div k).denote ctx * k + p.getConst % k = p.denote ctx := by
theorem Poly.denote_div_eq_of_divCoeffs (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → (p.div k).denote ctx * k + cmod p.getConst k = p.denote ctx := by
induction p with
| num k' => simp; rw [Int.add_comm, Int.mul_comm, Int.ediv_add_emod]
| num k' => simp; rw [Int.mul_comm, cdiv_add_cmod]
| add k' v p ih =>
simp; intro h₁ h₂
rw [← ih h₂]
@ -305,7 +360,7 @@ attribute [local simp] PolyCnstr.divAll PolyCnstr.div
theorem ExprCnstr.eq_of_toPoly_eq_of_divBy' (ctx : Context) (e e' : ExprCnstr) (p : PolyCnstr) (k : Int) : k > 0 → p.divAll k → e.toPoly = p → e'.toPoly = p.div k → e.denote ctx = e'.denote ctx := by
intro h₀ h₁ h₂ h₃
have hz : k ≠ 0 := by intro h; simp [h] at h₀
have hz : k ≠ 0 := Int.ne_of_gt h₀
cases p <;> simp at h₁
next p =>
replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁
@ -329,12 +384,67 @@ theorem ExprCnstr.eq_of_toPoly_eq_of_divBy' (ctx : Context) (e e' : ExprCnstr) (
rw [denote_toPoly, denote_toPoly] at this
exact this
theorem ExprCnstr.eq_of_toPoly_eq_of_divBy (ctx : Context) (e e' : ExprCnstr) (k : Int) : divBy e e' k → e.denote ctx = e'.denote ctx := by
theorem ExprCnstr.eq_of_divBy (ctx : Context) (e e' : ExprCnstr) (k : Int) : divBy e e' k → e.denote ctx = e'.denote ctx := by
intro h
simp only [divBy, Bool.and_eq_true, bne_iff_ne, ne_eq, beq_iff_eq, decide_eq_true_eq] at h
have ⟨⟨h₁, h₂⟩, h₃⟩ := h
exact ExprCnstr.eq_of_toPoly_eq_of_divBy' ctx e e' e.toPoly k h₁ h₂ rfl h₃
private theorem mul_add_cmod_le_iff {a k b : Int} (h : k > 0) : a*k + cmod b k ≤ 0 ↔ a ≤ 0 := by
constructor
· intro h'
have h₁ : a*k ≤ -cmod b k := by
have := Int.le_sub_right_of_add_le h'
simp at this
assumption
have h₂ : -cmod b k < k := by
have := cmod_gt_of_pos b h
have := Int.neg_lt_neg this
simp at this
assumption
have h₃ : a*k < k := Int.lt_of_le_of_lt h₁ h₂
have h₄ : a < 1 := by
conv at h₃ => rhs; rw [← Int.one_mul k]
have := Int.lt_of_mul_lt_mul_right h₃ (Int.le_of_lt h)
assumption
exact Int.le_of_lt_add_one (h₄ : a < 0 + 1)
· intro h'
have : a * k ≤ 0 := Int.mul_nonpos_of_nonpos_of_nonneg h' (Int.le_of_lt h)
have := Int.add_le_add this (cmod_nonpos b (Int.ne_of_gt h))
simp at this
assumption
theorem ExprCnstr.eq_of_toPoly_eq_of_divCoeffs (ctx : Context) (e e' : ExprCnstr) (p : PolyCnstr) (k : Int) : k > 0 → p.divCoeffs k → p.isLe → e.toPoly = p → e'.toPoly = p.div k → e.denote ctx = e'.denote ctx := by
intro h₀ h₁ h₂ h₃ h₄
have hz : k ≠ 0 := Int.ne_of_gt h₀
cases p <;> simp [PolyCnstr.isLe] at h₂
clear h₂
next p =>
simp [PolyCnstr.divCoeffs] at h₁
replace h₁ := Poly.denote_div_eq_of_divCoeffs ctx p k h₁
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_2, ← h₁] at h₃
replace h₄ := congrArg (PolyCnstr.denote ctx) h₄
simp only [PolyCnstr.denote.eq_2, PolyCnstr.div] at h₄
rw [denote_toPoly] at h₃ h₄
rw [h₃, h₄]
apply propext
apply mul_add_cmod_le_iff
exact h₀
-- Certificate for normalizing the coefficients of inequality constraint with bound tightening
def divByLe (e e' : ExprCnstr) (k : Int) : Bool :=
k > 0 && e.isLe && e.toPoly.divCoeffs k && e'.toPoly == e.toPoly.div k
theorem ExprCnstr.eq_of_divByLe (ctx : Context) (e e' : ExprCnstr) (k : Int) : divByLe e e' k → e.denote ctx = e'.denote ctx := by
intro h
simp only [divByLe, Bool.and_eq_true, bne_iff_ne, ne_eq, beq_iff_eq, decide_eq_true_eq] at h
have ⟨⟨⟨h₀, h₁⟩, h₂⟩, h₃⟩ := h
have hle : e.toPoly.isLe := by
cases e <;> simp [ExprCnstr.isLe] at h₁
simp [PolyCnstr.isLe]
apply ExprCnstr.eq_of_toPoly_eq_of_divCoeffs ctx e e' e.toPoly k h₀ h₂ hle rfl h₃
def PolyCnstr.isUnsat : PolyCnstr → Bool
| .eq (.num k) => k != 0
| .eq _ => false
@ -351,10 +461,10 @@ theorem ExprCnstr.eq_false_of_isUnsat (ctx : Context) (c : ExprCnstr) (h : c.toP
assumption
def PolyCnstr.isUnsatCoeff (k : Int) : PolyCnstr → Bool
| .eq p => p.divCoeffs k && k > 0 && p.getConst % k > 0
| .eq p => p.divCoeffs k && k > 0 && cmod p.getConst k < 0
| .le _ => false
private theorem contra {a b k : Int} (h₀ : 0 < k) (h₁ : 0 < b) (h₂ : b < k) (h₃ : a*k + b = 0) : False := by
private theorem contra_old {a b k : Int} (h₀ : 0 < k) (h₁ : 0 < b) (h₂ : b < k) (h₃ : a*k + b = 0) : False := by
have : b = -a*k := by
rw [← Int.neg_eq_of_add_eq_zero h₃, Int.neg_mul]
rw [this] at h₁ h₂
@ -366,14 +476,29 @@ private theorem contra {a b k : Int} (h₀ : 0 < k) (h₁ : 0 < b) (h₂ : b < k
have : (1 : Int) < 1 := Int.lt_of_le_of_lt low high
contradiction
private theorem PolyCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → k > 0 → p.getConst % k > 0 → (PolyCnstr.eq p).denote ctx = False := by
private theorem contra {a b k : Int} (h₀ : 0 < k) (h₁ : -k < b) (h₂ : b < 0) (h₃ : a*k + b = 0) : False := by
have : b = -a*k := by
rw [← Int.neg_eq_of_add_eq_zero h₃, Int.neg_mul]
rw [this, Int.neg_mul] at h₁ h₂
replace h₁ := Int.lt_of_neg_lt_neg h₁
replace h₂ : -(a*k) < -0 := h₂
replace h₂ := Int.lt_of_neg_lt_neg h₂
replace h₁ : a * k < 1 * k := by simp [h₁]
replace h₁ : a < 1 := Int.lt_of_mul_lt_mul_right h₁ (Int.le_of_lt h₀)
replace h₂ : 0 * k < a * k := by simp [h₂]
replace h₂ : 0 < a := Int.lt_of_mul_lt_mul_right h₂ (Int.le_of_lt h₀)
replace h₂ : 1 ≤ a := h₂
have : (1 : Int) < 1 := Int.lt_of_le_of_lt h₂ h₁
contradiction
private theorem PolyCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → k > 0 → cmod p.getConst k < 0 → (PolyCnstr.eq p).denote ctx = False := by
simp
intro h₁ h₂ h₃ h
have hnz : k ≠ 0 := by intro h; rw [h] at h₂; contradiction
have := Poly.denote_div_eq_of_divCoeffs ctx p k h₁
rw [h] at this
have low := h₃
have high := Int.emod_lt_of_pos p.getConst h₂
have low := cmod_gt_of_pos p.getConst h₂
have high := h₃
exact contra h₂ low high this
theorem ExprCnstr.eq_false_of_isUnsat_coeff (ctx : Context) (c : ExprCnstr) (k : Int) : c.toPoly.isUnsatCoeff k → c.denote ctx = False := by

View file

@ -69,25 +69,26 @@ def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
| _ =>
let defaultK := do
let k := p.gcdCoeffs
if k == 1 then
let r ← c'.toArith atoms
let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
let k := p.gcdCoeffs
if k == 1 then
defaultK
else if p.getConst % k == 0 then
let c' : LinearCnstr := (p.div k).toExprCnstr
let r ← c'.toArith atoms
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_of_divBy) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_divBy) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else if p.isEq then
let r := mkConst ``False
let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat_coeff) (toContextExpr atoms) (toExpr c) (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else
-- TODO: tight the bound
defaultK
-- `p.isLe`: tighten the bound
let c' : LinearCnstr := (p.div k).toExprCnstr
let r ← c'.toArith atoms
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_divByLe) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else
return none

View file

@ -235,3 +235,24 @@ example (x y : Int) (h : x + x + x = 1 + 2*y + x) : False := by
example (x : Int) (h : -x - x = 1) : False := by
simp +arith only at h
example (x : Int) (h : 2*x ≤ 1) : x ≤ 0 := by
simp +arith only at h
guard_hyp h : x ≤ 0
assumption
example (x y : Int) (h : 6*x + y + y + y ≤ 7) : 2*x + y + -2 ≤ 0 := by
simp +arith only at h
guard_hyp h : 2*x + y + -2 ≤ 0
assumption
example (x y : Int) (h : 5*x + y + y + y ≤ 7 - x) : 2*x + y + -2 ≤ 0 := by
simp +arith only at h
guard_hyp h : 2*x + y + -2 ≤ 0
assumption
example (x : Int) : (11*x ≤ 10) ↔ (x ≤ 0) := by
simp +arith only
example (x : Int) : (11*x > 10) ↔ (x ≥ 1) := by
simp +arith only