feat: non-chronological backtracking for cutsat (#7284)

This PR implements non-choronological backtracking for the cutsat
procedure. The procedure has two main kinds of case-splits:
disequalities and Cooper resolvents. This PR focus on the first kind.
This commit is contained in:
Leonardo de Moura 2025-03-01 15:19:11 -08:00 committed by GitHub
parent c4d3a74f32
commit a86145b6bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 211 additions and 42 deletions

View file

@ -1020,6 +1020,9 @@ theorem diseq_coeff (ctx : Context) (p p' : Poly) (k : Int) : eq_coeff_cert p p'
simp [eq_coeff_cert]
intro _ _; simp [mul_eq_zero_iff, *]
theorem diseq_neg (ctx : Context) (p p' : Poly) : p' == p.mul (-1) → p.denote' ctx ≠ 0 → p'.denote' ctx ≠ 0 := by
simp; intro _ _; simp [mul_eq_zero_iff, *]
theorem diseq_unsat (ctx : Context) (p : Poly) : p.isUnsatDiseq → p.denote' ctx ≠ 0 → False := by
simp [Poly.isUnsatDiseq] <;> split <;> simp
@ -1067,6 +1070,24 @@ theorem le_of_le_diseq (ctx : Context) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly)
next h => have := Int.lt_of_le_of_lt h₁ h; simp at this
intro h; cases h <;> intro <;> subst p₂ p₃ <;> simp <;> apply this
def diseq_split_cert (p₁ p₂ p₃ : Poly) : Bool :=
p₂ == p₁.addConst 1 &&
p₃ == (p₁.mul (-1)).addConst 1
theorem diseq_split (ctx : Context) (p₁ p₂ p₃ : Poly)
: diseq_split_cert p₁ p₂ p₃ → p₁.denote' ctx ≠ 0 → p₂.denote' ctx ≤ 0 p₃.denote' ctx ≤ 0 := by
simp [diseq_split_cert]
intro _ _; subst p₂ p₃; simp
generalize p₁.denote ctx = p
intro h; cases Int.lt_or_gt_of_ne h
next h => have := Int.add_one_le_of_lt h; rw [Int.add_comm]; simp [*]
next h => have := Int.add_one_le_of_lt (Int.neg_lt_neg h); simp at this; simp [*]
theorem diseq_split_resolve (ctx : Context) (p₁ p₂ p₃ : Poly)
: diseq_split_cert p₁ p₂ p₃ → p₁.denote' ctx ≠ 0 → ¬p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by
intro h₁ h₂ h₃
exact (diseq_split ctx p₁ p₂ p₃ h₁ h₂).resolve_left h₃
def OrOver (n : Nat) (p : Nat → Prop) : Prop :=
match n with
| 0 => False

View file

@ -51,5 +51,7 @@ builtin_initialize registerTraceClass `grind.cutsat.diseq.trivial (inherited :=
builtin_initialize registerTraceClass `grind.debug.cutsat.eq
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq.split
builtin_initialize registerTraceClass `grind.debug.cutsat.backtrack
end Lean

View file

@ -96,6 +96,12 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := c'.caching do
return mkApp7 (mkConst ``Int.Linear.le_of_le_diseq)
(← getContext) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p)
reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof)
| .ofDiseqSplit c₁ fvarId h _ =>
let p₂ := c₁.p.addConst 1
let hFalse ← h.toExprProofCore
let hNot := mkLambda `h .default (mkIntLE (← p₂.denoteExpr') (mkIntLit 0)) (hFalse.abstract #[mkFVar fvarId])
return mkApp7 (mkConst ``Int.Linear.diseq_split_resolve)
(← getContext) (toExpr c₁.p) (toExpr p₂) (toExpr c'.p) reflBoolTrue (← c₁.toExprProof) hNot
partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := c'.caching do
match c'.h with
@ -108,32 +114,36 @@ partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := c'.caching
| .divCoeffs c =>
let k := c.p.gcdCoeffs c.p.getConst
return mkApp6 (mkConst ``Int.Linear.diseq_coeff) (← getContext) (toExpr c.p) (toExpr c'.p) (toExpr k) reflBoolTrue (← c.toExprProof)
| .neg c =>
return mkApp5 (mkConst ``Int.Linear.diseq_neg) (← getContext) (toExpr c.p) (toExpr c'.p) reflBoolTrue (← c.toExprProof)
| .subst x c₁ c₂ =>
return mkApp8 (mkConst ``Int.Linear.eq_diseq_subst)
(← getContext) (toExpr x) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p)
reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof)
partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
match h with
| .le c =>
trace[grind.cutsat.le.unsat] "{← c.pp}"
return mkApp4 (mkConst ``Int.Linear.le_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof)
| .dvd c =>
trace[grind.cutsat.dvd.unsat] "{← c.pp}"
return mkApp5 (mkConst ``Int.Linear.dvd_unsat) (← getContext) (toExpr c.d) (toExpr c.p) reflBoolTrue (← c.toExprProof)
| .eq c =>
trace[grind.cutsat.eq.unsat] "{← c.pp}"
if c.p.isUnsatEq then
return mkApp4 (mkConst ``Int.Linear.eq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof)
else
let k := c.p.gcdCoeffs'
return mkApp5 (mkConst ``Int.Linear.eq_unsat_coeff) (← getContext) (toExpr c.p) (toExpr (Int.ofNat k)) reflBoolTrue (← c.toExprProof)
| .diseq c =>
trace[grind.cutsat.diseq.unsat] "{← c.pp}"
return mkApp4 (mkConst ``Int.Linear.diseq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof)
end
def UnsatProof.toExprProof (h : UnsatProof) : GoalM Expr := do
withProofContext do
match h with
| .le c =>
trace[grind.cutsat.le.unsat] "{← c.pp}"
return mkApp4 (mkConst ``Int.Linear.le_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof)
| .dvd c =>
trace[grind.cutsat.dvd.unsat] "{← c.pp}"
return mkApp5 (mkConst ``Int.Linear.dvd_unsat) (← getContext) (toExpr c.d) (toExpr c.p) reflBoolTrue (← c.toExprProof)
| .eq c =>
trace[grind.cutsat.eq.unsat] "{← c.pp}"
if c.p.isUnsatEq then
return mkApp4 (mkConst ``Int.Linear.eq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof)
else
let k := c.p.gcdCoeffs'
return mkApp5 (mkConst ``Int.Linear.eq_unsat_coeff) (← getContext) (toExpr c.p) (toExpr (Int.ofNat k)) reflBoolTrue (← c.toExprProof)
| .diseq c =>
trace[grind.cutsat.diseq.unsat] "{← c.pp}"
return mkApp4 (mkConst ``Int.Linear.diseq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof)
withProofContext do h.toExprProofCore
def setInconsistent (h : UnsatProof) : GoalM Unit := do
if (← get').caseSplits then
@ -143,4 +153,70 @@ def setInconsistent (h : UnsatProof) : GoalM Unit := do
let h ← h.toExprProof
closeGoal h
/-!
A cutsat proof may depend on decision variables.
We collect them and perform non chronological backtracking.
-/
structure CollectDecVars.State where
visited : Std.HashSet Nat := {}
found : FVarIdSet := {}
abbrev CollectDecVarsM := ReaderT FVarIdSet (StateM CollectDecVars.State)
private def alreadyVisited (id : Nat) : CollectDecVarsM Bool := do
if (← get).visited.contains id then return true
modify fun s => { s with visited := s.visited.insert id }
return false
private def markAsFound (fvarId : FVarId) : CollectDecVarsM Unit := do
modify fun s => { s with found := s.found.insert fvarId }
private def collectExpr (e : Expr) : CollectDecVarsM Unit := do
let .fvar fvarId := e | return ()
if (← read).contains fvarId then
markAsFound fvarId
mutual
partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do
match c'.h with
| .expr h => collectExpr h
| .core .. => return () -- Equalities coming from the core never contain cutsat decision variables
| .norm c | .divCoeffs c => c.collectDecVars
| .subst _ c₁ c₂ | .ofLeGe c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
partial def DvdCnstr.collectDecVars (c' : DvdCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do
match c'.h with
| .expr h => collectExpr h
| .norm c | .elim c | .divCoeffs c | .ofEq _ c => c.collectDecVars
| .solveCombine c₁ c₂ | .solveElim c₁ c₂ | .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
partial def LeCnstr.collectDecVars (c' : LeCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do
match c'.h with
| .expr h => collectExpr h
| .notExpr .. => return () -- This kind of proof is used for connecting with the `grind` core.
| .norm c | .divCoeffs c => c.collectDecVars
| .combine c₁ c₂ | .subst _ c₁ c₂ | .ofLeDiseq c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
| .ofDiseqSplit _ _ _ decVars =>
-- Recall that we cache the decision variables used in this kind of proof
for fvar in decVars do
markAsFound fvar
partial def DiseqCnstr.collectDecVars (c' : DiseqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do
match c'.h with
| .expr h => collectExpr h
| .core .. => return () -- Disequalities coming from the core never contain cutsat decision variables
| .norm c | .divCoeffs c | .neg c => c.collectDecVars
| .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
end
def UnsatProof.collectDecVars (h : UnsatProof) : CollectDecVarsM Unit := do
match h with
| .le c | .dvd c | .eq c | .diseq c => c.collectDecVars
abbrev CollectDecVarsM.run (x : CollectDecVarsM Unit) (decVars : FVarIdSet) : FVarIdSet :=
let (_, s) := x decVars |>.run {}
s.found
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -8,6 +8,7 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
import Lean.Meta.Tactic.Grind.Arith.Cutsat.DvdCnstr
import Lean.Meta.Tactic.Grind.Arith.Cutsat.LeCnstr
import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr
import Lean.Meta.Tactic.Grind.Arith.Cutsat.SearchM
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Model
@ -265,8 +266,28 @@ def resolveCooperDvd (c₁ c₂ : LeCnstr) (c : DvdCnstr) : GoalM Unit := do
def resolveCooperDiseq (c₁ : DiseqCnstr) (c₂ : LeCnstr) (_c? : Option DvdCnstr) : GoalM Unit := do
throwError "Cooper-diseq NIY {← c₁.pp} {← c₂.pp}"
def resolveRatDiseq (c₁ : LeCnstr) (c : DiseqCnstr) : GoalM Unit := do
throwError "diseq NIY {← c₁.pp} {← c.pp}"
/--
Given `c₁` of the form `-a₁*x + p₁ ≤ 0`, and `c` of the form `b*x + p ≠ 0`,
splits `c` and resolve with `c₁`.
Recall that a disequality
-/
def resolveRatDiseq (c₁ : LeCnstr) (c : DiseqCnstr) : SearchM Unit := do
let c ← if c.p.leadCoeff < 0 then
mkDiseqCnstr (c.p.mul (-1)) (.neg c)
else
pure c
let fvarId ← if let some fvarId := (← get').diseqSplits.find? c.p then
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, reusing {fvarId.name}"
pure fvarId
else
let fvarId ← mkCase (.diseq c)
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, {fvarId.name}"
modify' fun s => { s with diseqSplits := s.diseqSplits.insert c.p fvarId }
pure fvarId
let p₂ := c.p.addConst 1
let c₂ ← mkLeCnstr p₂ (.expr (mkFVar fvarId))
let b ← resolveRealLowerUpperConflict c₁ c₂
assert! b
def processVar (x : Var) : SearchM Unit := do
if (← eliminated x) then
@ -334,20 +355,47 @@ def processVar (x : Var) : SearchM Unit := do
def hasAssignment : GoalM Bool := do
return (← get').vars.size == (← get').assignment.size
private def isDone : GoalM Bool := do
if (← hasAssignment) then
private def findCase (decVars : FVarIdSet) : SearchM Case := do
repeat
let numCases := (← get).cases.size
assert! numCases > 0
let case := (← get).cases[numCases-1]!
modify fun s => { s with cases := s.cases.pop }
if decVars.contains case.fvarId then
return case
-- Conflict does not depend on this case.
trace[grind.debug.cutsat.backtrack] "skipping {case.fvarId.name}"
unreachable!
def resolveConflict (h : UnsatProof) : SearchM Bool := do
let decVars := h.collectDecVars.run (← get).decVars
if decVars.isEmpty then
closeGoal (← h.toExprProof)
return false
let c ← findCase decVars
modify' fun _ => c.saved
match c.kind with
| .diseq c₁ =>
let decVars := decVars.erase c.fvarId |>.toArray
let p' := c₁.p.mul (-1) |>.addConst 1
let c' ← mkLeCnstr p' (.ofDiseqSplit c₁ c.fvarId h decVars)
trace[grind.debug.cutsat.backtrack] "resolved diseq split: {← c'.pp}"
c'.assert
return true
if (← inconsistent) then
return true
return false
| _ => throwError "NIY resolve conflict"
/-- Search for an assignment/model for the linear constraints. -/
def searchAssigmentMain : SearchM Unit := do
repeat
if (← isDone) then
if (← hasAssignment) then
return ()
if (← isInconsistent) then
-- `grind` state is inconsistent
return ()
if let some c := (← get').conflict? then
unless (← resolveConflict c) do
return ()
let x : Var := (← get').assignment.size
-- TODO: resolve unsat conflicts
processVar x
def traceModel : GoalM Unit := do

View file

@ -14,7 +14,7 @@ In principle, we only need to support two kinds of case split.
- Cooper-Left, but we have 4 different variants of this one.
-/
inductive CaseKind where
| diseq
| diseq (d : DiseqCnstr)
| copperLeft
| copperDvdLeft
| cooperRight
@ -30,11 +30,6 @@ structure Case where
-/
fvarId : FVarId
/--
Decision variable as a Lean type. We use it to construct
the actual proof term.
-/
type : Expr
/--
Snapshot of the cutsat state for backtracking purposes.
We do not use a trail stack.
-/
@ -75,4 +70,14 @@ def isApprox : SearchM Bool :=
def setImprecise : SearchM Unit := do
modify fun s => { s with precise := false }
def mkCase (kind : CaseKind) : SearchM FVarId := do
let fvarId ← mkFreshFVarId
let saved ← get'
modify fun s => { s with
cases := s.cases.push { saved, fvarId, kind }
decVars := s.decVars.insert fvarId
}
modify' fun s => { s with caseSplits := true }
return fvarId
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -15,6 +15,8 @@ namespace Lean.Meta.Grind.Arith.Cutsat
export Int.Linear (Var Poly)
export Std.Internal (Rat)
deriving instance Hashable for Poly
/-!
This module implements a model-based decision procedure for linear integer arithmetic,
inspired by Section 4 of "Cutting to the Chase: Solving Linear Integer Arithmetic".
@ -115,6 +117,7 @@ inductive LeCnstrProof where
| combine (c₁ c₂ : LeCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : LeCnstr)
| ofLeDiseq (c₁ : LeCnstr) (c₂ : DiseqCnstr)
| ofDiseqSplit (c₁ : DiseqCnstr) (decVar : FVarId) (h : UnsatProof) (decVars : Array FVarId)
-- TODO: missing constructors
/-- A disequality constraint and its justification/proof. -/
@ -128,13 +131,9 @@ inductive DiseqCnstrProof where
| core (p₁ p₂ : Poly) (h : Expr)
| norm (c : DiseqCnstr)
| divCoeffs (c : DiseqCnstr)
| neg (c : DiseqCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : DiseqCnstr)
end
instance : Inhabited DvdCnstr where
default := { d := 0, p := .num 0, h := .expr default, id := 0 }
/--
A proof of `False`.
Remark: We will later add support for a backtraking search inside of cutsat.
@ -145,6 +144,11 @@ inductive UnsatProof where
| eq (c : EqCnstr)
| diseq (c : DiseqCnstr)
end
instance : Inhabited DvdCnstr where
default := { d := 0, p := .num 0, h := .expr default, id := 0 }
abbrev VarSet := RBTree Var compare
/-- State of the cutsat procedure. -/
@ -209,6 +213,12 @@ structure State where
can convert `UnsatProof` into a Lean term and close the current `grind` goal.
-/
conflict? : Option UnsatProof := none
/--
Cache decision variables used when splitting on disequalities.
This is necessary because the same disequality may be in different conflicts.
-/
diseqSplits : PHashMap Poly FVarId := {}
/-
TODO: Model-based theory combination.
-/

View file

@ -46,9 +46,8 @@ def get' : GoalM State := do
/-- Returns `true` if the cutsat state is inconsistent. -/
def inconsistent : GoalM Bool := do
-- TODO: we will have a nested backtracking search in cutsat
-- and this function will have to be refined.
isInconsistent
if (← isInconsistent) then return true
return (← get').conflict?.isSome
def getVars : GoalM (PArray Expr) :=
return (← get').vars

View file

@ -0,0 +1,8 @@
set_option grind.warning false
set_option grind.debug true
open Int.Linear
theorem ex₁ (a b c : Int) : c ≥ 0 → b ≥ 0 → 1 ≤ a + c → a + b ≤ 1 → a ≠ 1 → c ≤ a → False := by
grind
#print ex₁

View file

@ -1,5 +1,5 @@
set_option grind.warning false
-- set_option grind.debug true
set_option grind.debug true
open Int.Linear
example (a b c d e : Int) :