feat: improve support for equations in cutsat (#7203)

This PR improves the support for equalities in cutsat. It also
simplifies a few support theorems used to justify cutsat rules.
This commit is contained in:
Leonardo de Moura 2025-02-23 20:48:14 -08:00 committed by GitHub
parent 1819dc88ff
commit e7dc0d31f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 171 additions and 107 deletions

View file

@ -850,6 +850,12 @@ theorem le_unsat (ctx : Context) (p : Poly) : p.isUnsatLe → p.denote' ctx ≤
have := Int.lt_of_le_of_lt h₂ h₁
simp at this
theorem eq_norm (ctx : Context) (p₁ p₂ : Poly) (h : p₁.norm == p₂) : p₁.denote' ctx = 0 → p₂.denote' ctx = 0 := by
simp at h
replace h := congrArg (Poly.denote ctx) h
simp at h
simp [*]
def Poly.coeff (p : Poly) (x : Var) : Int :=
match p with
| .add a y p => bif x == y then a else coeff p x
@ -864,17 +870,28 @@ private theorem dvd_of_eq' {a x p : Int} : a*x + p = 0 → a p := by
rw [Int.mul_comm, ← Int.neg_mul, Eq.comm, Int.mul_comm] at h
exact ⟨-x, h⟩
private def abs (x : Int) : Int :=
Int.ofNat x.natAbs
private theorem abs_dvd {a p : Int} (h : a p) : abs a p := by
cases a <;> simp [abs]
· simp at h; assumption
· simp [Int.negSucc_eq] at h; assumption
def dvd_of_eq_cert (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly) : Bool :=
d₂ == p₁.coeff x && p₂ == p₁.insert (-d₂) x
let a := p₁.coeff x
d₂ == abs a && p₂ == p₁.insert (-a) x
theorem dvd_of_eq (ctx : Context) (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly)
: dvd_of_eq_cert x p₁ d₂ p₂ → p₁.denote' ctx = 0 → d₂ p₂.denote' ctx := by
simp [dvd_of_eq_cert]
intro h₁ h₂
have h := eq_add_coeff_insert ctx p₁ x
rw [← h₁, ← h₂] at h
rw [h]
apply dvd_of_eq'
rw [← h₂] at h
rw [h, h₁]
intro h₃
apply abs_dvd
apply dvd_of_eq' h₃
private theorem eq_dvd_subst' {a x p d b q : Int} : a*x + p = 0 → d b*x + q → a*d a*q - b*p := by
intro h₁ ⟨z, h₂⟩
@ -892,7 +909,7 @@ def eq_dvd_subst_cert (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly) (d₃ :
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
d₃ == a * d₂ &&
d₃ == abs (a * d₂) &&
p₃ == (q.mul a |>.combine (p.mul (-b)))
theorem eq_dvd_subst (ctx : Context) (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly) (d₃ : Int) (p₃ : Poly)
@ -913,124 +930,53 @@ theorem eq_dvd_subst (ctx : Context) (x : Var) (p₁ : Poly) (d₂ : Int) (p₂
rw [Int.add_comm] at h₁ h₂
have := eq_dvd_subst' h₁ h₂
rw [Int.sub_eq_add_neg, Int.add_comm] at this
apply abs_dvd
simp [this]
private theorem eq_eq_subst' {a x p b q : Int} : a*x + p = 0 → b*x + q = 0 → b*p - a*q = 0 := by
intro h₁ h₂
replace h₁ := congrArg (b*·) h₁; simp at h₁
replace h₂ := congrArg ((-a)*.) h₂; simp at h₂
rw [Int.add_comm] at h₁
replace h₁ := Int.neg_eq_of_add_eq_zero h₁
rw [← h₁]; clear h₁
replace h₂ := Int.neg_eq_of_add_eq_zero h₂; simp at h₂
rw [h₂]; clear h₂
rw [Int.mul_left_comm]
simp
def eq_eq_subst_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool :=
let a := p₁.coeff x
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
p₃ == (p.mul b |>.combine (q.mul (-a)))
p₃ == (p₁.mul b |>.combine (p₂.mul (-a)))
theorem eq_eq_subst (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly)
: eq_eq_subst_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 → p₃.denote' ctx = 0 := by
simp [eq_eq_subst_cert]
have eq₁ := eq_add_coeff_insert ctx p₁ x
have eq₂ := eq_add_coeff_insert ctx p₂ x
revert eq₁ eq₂
generalize p₁.coeff x = a
generalize p₂.coeff x = b
generalize p₁.insert (-a) x = p
generalize p₂.insert (-b) x = q
intro eq₁; simp [eq₁]; clear eq₁
intro eq₂; simp [eq₂]; clear eq₂
intro; subst p₃
intro h₁ h₂
rw [Int.add_comm] at h₁ h₂
have := eq_eq_subst' h₁ h₂
rw [Int.sub_eq_add_neg] at this
simp [this]
private theorem eq_le_subst_nonneg' {a x p b q : Int} : a ≥ 0 → a*x + p = 0 → b*x + q ≤ 0 → a*q - b*p ≤ 0 := by
intro h h₁ h₂
replace h₁ := congrArg ((-b)*·) h₁; simp at h₁
rw [Int.add_comm, Int.mul_left_comm] at h₁
replace h₁ := Int.neg_eq_of_add_eq_zero h₁; simp at h₁
replace h₂ := Int.mul_le_mul_of_nonneg_left h₂ h
rw [Int.mul_add, h₁] at h₂; clear h₁
simp at h₂
rw [Int.sub_eq_add_neg]
assumption
simp [*]
def eq_le_subst_nonneg_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool :=
let a := p₁.coeff x
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
a ≥ 0 && p₃ == (q.mul a |>.combine (p.mul (-b)))
a ≥ 0 && p₃ == (p₂.mul a |>.combine (p₁.mul (-b)))
theorem eq_le_subst_nonneg (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly)
: eq_le_subst_nonneg_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by
simp [eq_le_subst_nonneg_cert]
have eq₁ := eq_add_coeff_insert ctx p₁ x
have eq₂ := eq_add_coeff_insert ctx p₂ x
revert eq₁ eq₂
generalize p₁.coeff x = a
generalize p₂.coeff x = b
generalize p₁.insert (-a) x = p
generalize p₂.insert (-b) x = q
intro eq₁; simp [eq₁]; clear eq₁
intro eq₂; simp [eq₂]; clear eq₂
intro h
intro; subst p₃
intro h₁ h₂
rw [Int.add_comm] at h₁ h₂
have := eq_le_subst_nonneg' h h₁ h₂
rw [Int.sub_eq_add_neg, Int.add_comm] at this
simp [this]
private theorem eq_le_subst_nonpos' {a x p b q : Int} : a ≤ 0 → a*x + p = 0 → b*x + q ≤ 0 → b*p - a*q ≤ 0 := by
intro h h₁ h₂
replace h₁ := congrArg (b*·) h₁; simp at h₁
rw [Int.add_comm, Int.mul_left_comm] at h₁
replace h₁ := Int.neg_eq_of_add_eq_zero h₁; simp at h₁
replace h : (-a) ≥ 0 := by
have := Int.neg_le_neg h
simp at this
exact this
replace h₂ := Int.mul_le_mul_of_nonneg_left h₂ h; simp at h₂; clear h
rw [h₁] at h₂
rw [Int.add_comm, ←Int.sub_eq_add_neg] at h₂
assumption
replace h₂ := Int.mul_le_mul_of_nonneg_left h₂ h
simp at h₂
simp [*]
def eq_le_subst_nonpos_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool :=
let a := p₁.coeff x
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
a ≤ 0 && p₃ == (p.mul b |>.combine (q.mul (-a)))
a ≤ 0 && p₃ == (p₁.mul b |>.combine (p₂.mul (-a)))
theorem eq_le_subst_nonpos (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly)
: eq_le_subst_nonpos_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by
simp [eq_le_subst_nonpos_cert]
have eq₁ := eq_add_coeff_insert ctx p₁ x
have eq₂ := eq_add_coeff_insert ctx p₂ x
revert eq₁ eq₂
generalize p₁.coeff x = a
generalize p₂.coeff x = b
generalize p₁.insert (-a) x = p
generalize p₂.insert (-b) x = q
intro eq₁; simp [eq₁]; clear eq₁
intro eq₂; simp [eq₂]; clear eq₂
intro h
intro; subst p₃
intro h₁ h₂
rw [Int.add_comm] at h₁ h₂
have := eq_le_subst_nonpos' h h₁ h₂
rw [Int.sub_eq_add_neg] at this
simp [this]
simp [*]
replace h₂ := Int.mul_le_mul_of_nonpos_left h₂ h; simp at h₂; clear h
rw [← Int.neg_zero]
apply Int.neg_le_neg
rw [Int.mul_comm]
assumption
end Int.Linear

View file

@ -18,6 +18,7 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr
namespace Lean
builtin_initialize registerTraceClass `grind.cutsat
builtin_initialize registerTraceClass `grind.cutsat.subst
builtin_initialize registerTraceClass `grind.cutsat.eq
builtin_initialize registerTraceClass `grind.cutsat.assert
builtin_initialize registerTraceClass `grind.cutsat.assert.dvd

View file

@ -5,11 +5,80 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var
import Lean.Meta.Tactic.Grind.Arith.Cutsat.DvdCnstr
namespace Lean.Meta.Grind.Arith.Cutsat
def mkEqCnstr (p : Poly) (h : EqCnstrProof) : GoalM EqCnstr := do
return { p, h, id := (← mkCnstrId) }
def EqCnstr.norm (c : EqCnstr) : GoalM EqCnstr := do
let c ← if c.p.isSorted then
pure c
else
mkEqCnstr c.p.norm (.norm c)
/--
Selects the variable in the given linear polynomial whose coefficient has the smallest absolute value.
-/
def _root_.Int.Linear.Poly.pickVarToElim? (p : Poly) : Option (Int × Var) :=
match p with
| .num _ => none
| .add k x p => go k x p
where
go (k : Int) (x : Var) (p : Poly) : Int × Var :=
if k == 1 || k == -1 then
(k, x)
else match p with
| .num _ => (k, x)
| .add k' x' p =>
if k'.natAbs < k.natAbs then
go k' x' p
else
go k x p
/--
Given a polynomial `p`, returns `some (x, k, c)` if `p` contains the monomial `k*x`,
and `x` has been eliminated using the equality `c`.
-/
def _root_.Int.Linear.Poly.findVarToSubst (p : Poly) : GoalM (Option (Int × Var × EqCnstr)) := do
match p with
| .num _ => return none
| .add k x p =>
if let some c := (← get').elimEqs[x]! then
return some (k, x, c)
else
findVarToSubst p
partial def applySubsts (c : EqCnstr) : GoalM EqCnstr := do
let some (a, x, c₁) ← c.p.findVarToSubst | return c
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
let b := c₁.p.coeff x
let p := c.p.mul (-b) |>.combine (c₁.p.mul a)
let c ← mkEqCnstr p (.subst x c₁ c)
applySubsts c
def EqCnstr.assert (c : EqCnstr) : GoalM Unit := do
if (← isInconsistent) then return ()
trace[grind.cutsat.assert] "{← c.pp}"
let c ← c.norm
let c ← applySubsts c
-- TODO: check coeffsr
trace[grind.cutsat.eq] "{← c.pp}"
let some (k, x) := c.p.pickVarToElim? | c.throwUnexpected
-- TODO: eliminate `x` from lowers, uppers, and dvdCnstrs
-- TODO: reset `x`s occurrences
-- assert a divisibility constraint IF `|k| != 1`
if k.natAbs != 1 then
let p := c.p.insert (-k) x
let d := Int.ofNat k.natAbs
let c ← mkDvdCnstr d p (.ofEq x c)
c.assert
modify' fun s => { s with
elimEqs := s.elimEqs.set x (some c)
elimStack := x :: s.elimStack
}
@[export lean_process_cutsat_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := do
trace[grind.cutsat.eq] "{mkIntEq a b}"
@ -17,10 +86,14 @@ def processNewEqImpl (a b : Expr) : GoalM Unit := do
return ()
@[export lean_process_new_cutsat_lit]
def processNewEqLitImpl (a k : Expr) : GoalM Unit := do
trace[grind.cutsat.eq] "{mkIntEq a k}"
-- TODO
return ()
def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do
let some k ← getIntValue? ke | return ()
let some p := (← get').terms.find? { expr := a } | return ()
if k == 0 then
(← mkEqCnstr p (.expr (← mkEqProof a ke))).assert
else
-- TODO
return ()
/-- Different kinds of terms internalized by this module. -/
private inductive SupportedTermKind where

View file

@ -37,8 +37,11 @@ partial def DvdCnstr.toExprProof (c' : DvdCnstr) : ProofM Expr := c'.caching do
return mkApp10 (mkConst ``Int.Linear.dvd_solve_elim)
(← getContext) (toExpr c₁.d) (toExpr c₁.p) (toExpr c₂.d) (toExpr c₂.p) (toExpr c'.d) (toExpr c'.p)
reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof)
| .subst _c₁ _c₂ => throwError "NIY"
| .ofEq _c => throwError "NIY"
| .subst _x _c₁ _c₂ => throwError "NIY"
| .ofEq x c =>
return mkApp7 (mkConst ``Int.Linear.dvd_of_eq)
(← getContext) (toExpr x) (toExpr c.p) (toExpr c'.d) (toExpr c'.p)
reflBoolTrue (← c.toExprProof)
partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := c'.caching do
match c'.h with
@ -56,7 +59,18 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := c'.caching do
(← getContext) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p)
reflBoolTrue
(← c₁.toExprProof) (← c₂.toExprProof)
| .subst _c₁ _c₂ => throwError "NIY"
| .subst _x _c₁ _c₂ => throwError "NIY"
partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := c'.caching do
match c'.h with
| .expr h =>
return h
| .norm c =>
return mkApp5 (mkConst ``Int.Linear.eq_norm) (← getContext) (toExpr c.p) (toExpr c'.p) reflBoolTrue (← c.toExprProof)
| .subst x c₁ c₂ =>
return mkApp8 (mkConst ``Int.Linear.eq_eq_subst)
(← getContext) (toExpr x) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p)
reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof)
end
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -34,8 +34,8 @@ inductive DvdCnstrProof where
| solveCombine (c₁ c₂ : DvdCnstr)
| solveElim (c₁ c₂ : DvdCnstr)
| elim (c : DvdCnstr)
| ofEq (c : EqCnstr)
| subst (c₁ : EqCnstr) (c₂ : DvdCnstr)
| ofEq (x : Var) (c : EqCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : DvdCnstr)
structure LeCnstr where
p : Poly
@ -48,7 +48,7 @@ inductive LeCnstrProof where
| norm (c : LeCnstr)
| divCoeffs (c : LeCnstr)
| combine (c₁ c₂ : LeCnstr)
| subst (c₁ : EqCnstr) (c₂ : LeCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : LeCnstr)
-- TODO: missing constructors
structure EqCnstr where
@ -59,7 +59,7 @@ structure EqCnstr where
inductive EqCnstrProof where
| expr (h : Expr)
| norm (c : EqCnstr)
| subst (c₁ : EqCnstr) (c₂ : EqCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : EqCnstr)
end
/-- State of the cutsat procedure. -/

View file

@ -125,7 +125,7 @@ def EqCnstr.pp (c : EqCnstr) : GoalM MessageData := do
def EqCnstr.denoteExpr (c : EqCnstr) : GoalM Expr := do
return mkIntEq (← c.p.denoteExpr') (mkIntLit 0)
def EqCnstr.throwUnexpected (c : LeCnstr) : GoalM α := do
def EqCnstr.throwUnexpected (c : EqCnstr) : GoalM α := do
throwError "`grind` internal error, unexpected{indentD (← c.pp)}"
/-- Returns occurrences of `x`. -/
@ -176,11 +176,9 @@ abbrev caching (id : Nat) (k : ProofM Expr) : ProofM Expr := do
modify fun s => { s with cache := s.cache.insert id h }
return h
abbrev DvdCnstr.caching (c : DvdCnstr) (k : ProofM Expr) : ProofM Expr :=
Cutsat.caching c.id k
abbrev LeCnstr.caching (c : LeCnstr) (k : ProofM Expr) : ProofM Expr :=
Cutsat.caching c.id k
abbrev DvdCnstr.caching (c : DvdCnstr) (k : ProofM Expr) : ProofM Expr := Cutsat.caching c.id k
abbrev LeCnstr.caching (c : LeCnstr) (k : ProofM Expr) : ProofM Expr := Cutsat.caching c.id k
abbrev EqCnstr.caching (c : EqCnstr) (k : ProofM Expr) : ProofM Expr := Cutsat.caching c.id k
abbrev withProofContext (x : ProofM Expr) : GoalM Expr := do
withLetDecl `ctx (mkApp (mkConst ``RArray) (mkConst ``Int)) (← toContextExpr) fun ctx => do

View file

@ -0,0 +1,32 @@
set_option grind.warning false
-- set_option grind.debug true -- TODO: enable after making more progress in `EqCnstr.lean`
open Int.Linear
-- set_option trace.grind.cutsat.assert true
-- set_option trace.grind.cutsat.internalize true
/-- info: [grind.cutsat.eq] b + 「f a」 + 1 = 0 -/
#guard_msgs (info) in
set_option trace.grind.cutsat.eq true in
example (a b : Int) (f : Int → Int) (h₁ : f a + b + 3 = 2) : False := by
fail_if_success grind
sorry
theorem ex₁ (a b : Int) (_ : 2*a + 3*b = 0) (_ : 2 3*b + 1) : False := by
grind
theorem ex₂ (a b : Int) (_ : 2 3*a + 1) (_ : 2*b + 3*a = 0) : False := by
grind
set_option trace.grind.cutsat.subst true
theorem ex₃ (a b c : Int) (_ : c + 3*a = 0) (_ : 2 3*a + 1) (_ : 2*b + c = 0) : False := by
grind
theorem ex₄ (a b c : Int) (_ : 2*c + 3*a = 0) (_ : 2*b + c = 0) (_ : 2 3*a + 1) : False := by
grind
#print ex₁
#print ex₂
#print ex₃
#print ex₄