diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index b5b331142f..084597d600 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -1020,6 +1020,9 @@ theorem diseq_coeff (ctx : Context) (p p' : Poly) (k : Int) : eq_coeff_cert p p' simp [eq_coeff_cert] intro _ _; simp [mul_eq_zero_iff, *] +theorem diseq_neg (ctx : Context) (p p' : Poly) : p' == p.mul (-1) → p.denote' ctx ≠ 0 → p'.denote' ctx ≠ 0 := by + simp; intro _ _; simp [mul_eq_zero_iff, *] + theorem diseq_unsat (ctx : Context) (p : Poly) : p.isUnsatDiseq → p.denote' ctx ≠ 0 → False := by simp [Poly.isUnsatDiseq] <;> split <;> simp @@ -1067,6 +1070,24 @@ theorem le_of_le_diseq (ctx : Context) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) next h => have := Int.lt_of_le_of_lt h₁ h; simp at this intro h; cases h <;> intro <;> subst p₂ p₃ <;> simp <;> apply this +def diseq_split_cert (p₁ p₂ p₃ : Poly) : Bool := + p₂ == p₁.addConst 1 && + p₃ == (p₁.mul (-1)).addConst 1 + +theorem diseq_split (ctx : Context) (p₁ p₂ p₃ : Poly) + : diseq_split_cert p₁ p₂ p₃ → p₁.denote' ctx ≠ 0 → p₂.denote' ctx ≤ 0 ∨ p₃.denote' ctx ≤ 0 := by + simp [diseq_split_cert] + intro _ _; subst p₂ p₃; simp + generalize p₁.denote ctx = p + intro h; cases Int.lt_or_gt_of_ne h + next h => have := Int.add_one_le_of_lt h; rw [Int.add_comm]; simp [*] + next h => have := Int.add_one_le_of_lt (Int.neg_lt_neg h); simp at this; simp [*] + +theorem diseq_split_resolve (ctx : Context) (p₁ p₂ p₃ : Poly) + : diseq_split_cert p₁ p₂ p₃ → p₁.denote' ctx ≠ 0 → ¬p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by + intro h₁ h₂ h₃ + exact (diseq_split ctx p₁ p₂ p₃ h₁ h₂).resolve_left h₃ + def OrOver (n : Nat) (p : Nat → Prop) : Prop := match n with | 0 => False diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean index 45ae6df35f..0a2505b66a 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean @@ -51,5 +51,7 @@ builtin_initialize registerTraceClass `grind.cutsat.diseq.trivial (inherited := builtin_initialize registerTraceClass `grind.debug.cutsat.eq builtin_initialize registerTraceClass `grind.debug.cutsat.diseq +builtin_initialize registerTraceClass `grind.debug.cutsat.diseq.split +builtin_initialize registerTraceClass `grind.debug.cutsat.backtrack end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index 2cf9782e2c..8e3a61ddc4 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -96,6 +96,12 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := c'.caching do return mkApp7 (mkConst ``Int.Linear.le_of_le_diseq) (← getContext) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p) reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof) + | .ofDiseqSplit c₁ fvarId h _ => + let p₂ := c₁.p.addConst 1 + let hFalse ← h.toExprProofCore + let hNot := mkLambda `h .default (mkIntLE (← p₂.denoteExpr') (mkIntLit 0)) (hFalse.abstract #[mkFVar fvarId]) + return mkApp7 (mkConst ``Int.Linear.diseq_split_resolve) + (← getContext) (toExpr c₁.p) (toExpr p₂) (toExpr c'.p) reflBoolTrue (← c₁.toExprProof) hNot partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := c'.caching do match c'.h with @@ -108,32 +114,36 @@ partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := c'.caching | .divCoeffs c => let k := c.p.gcdCoeffs c.p.getConst return mkApp6 (mkConst ``Int.Linear.diseq_coeff) (← getContext) (toExpr c.p) (toExpr c'.p) (toExpr k) reflBoolTrue (← c.toExprProof) + | .neg c => + return mkApp5 (mkConst ``Int.Linear.diseq_neg) (← getContext) (toExpr c.p) (toExpr c'.p) reflBoolTrue (← c.toExprProof) | .subst x c₁ c₂ => return mkApp8 (mkConst ``Int.Linear.eq_diseq_subst) (← getContext) (toExpr x) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p) reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof) +partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do + match h with + | .le c => + trace[grind.cutsat.le.unsat] "{← c.pp}" + return mkApp4 (mkConst ``Int.Linear.le_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof) + | .dvd c => + trace[grind.cutsat.dvd.unsat] "{← c.pp}" + return mkApp5 (mkConst ``Int.Linear.dvd_unsat) (← getContext) (toExpr c.d) (toExpr c.p) reflBoolTrue (← c.toExprProof) + | .eq c => + trace[grind.cutsat.eq.unsat] "{← c.pp}" + if c.p.isUnsatEq then + return mkApp4 (mkConst ``Int.Linear.eq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof) + else + let k := c.p.gcdCoeffs' + return mkApp5 (mkConst ``Int.Linear.eq_unsat_coeff) (← getContext) (toExpr c.p) (toExpr (Int.ofNat k)) reflBoolTrue (← c.toExprProof) + | .diseq c => + trace[grind.cutsat.diseq.unsat] "{← c.pp}" + return mkApp4 (mkConst ``Int.Linear.diseq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof) + end def UnsatProof.toExprProof (h : UnsatProof) : GoalM Expr := do - withProofContext do - match h with - | .le c => - trace[grind.cutsat.le.unsat] "{← c.pp}" - return mkApp4 (mkConst ``Int.Linear.le_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof) - | .dvd c => - trace[grind.cutsat.dvd.unsat] "{← c.pp}" - return mkApp5 (mkConst ``Int.Linear.dvd_unsat) (← getContext) (toExpr c.d) (toExpr c.p) reflBoolTrue (← c.toExprProof) - | .eq c => - trace[grind.cutsat.eq.unsat] "{← c.pp}" - if c.p.isUnsatEq then - return mkApp4 (mkConst ``Int.Linear.eq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof) - else - let k := c.p.gcdCoeffs' - return mkApp5 (mkConst ``Int.Linear.eq_unsat_coeff) (← getContext) (toExpr c.p) (toExpr (Int.ofNat k)) reflBoolTrue (← c.toExprProof) - | .diseq c => - trace[grind.cutsat.diseq.unsat] "{← c.pp}" - return mkApp4 (mkConst ``Int.Linear.diseq_unsat) (← getContext) (toExpr c.p) reflBoolTrue (← c.toExprProof) + withProofContext do h.toExprProofCore def setInconsistent (h : UnsatProof) : GoalM Unit := do if (← get').caseSplits then @@ -143,4 +153,70 @@ def setInconsistent (h : UnsatProof) : GoalM Unit := do let h ← h.toExprProof closeGoal h +/-! +A cutsat proof may depend on decision variables. +We collect them and perform non chronological backtracking. +-/ + +structure CollectDecVars.State where + visited : Std.HashSet Nat := {} + found : FVarIdSet := {} + +abbrev CollectDecVarsM := ReaderT FVarIdSet (StateM CollectDecVars.State) + +private def alreadyVisited (id : Nat) : CollectDecVarsM Bool := do + if (← get).visited.contains id then return true + modify fun s => { s with visited := s.visited.insert id } + return false + +private def markAsFound (fvarId : FVarId) : CollectDecVarsM Unit := do + modify fun s => { s with found := s.found.insert fvarId } + +private def collectExpr (e : Expr) : CollectDecVarsM Unit := do + let .fvar fvarId := e | return () + if (← read).contains fvarId then + markAsFound fvarId + +mutual +partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do + match c'.h with + | .expr h => collectExpr h + | .core .. => return () -- Equalities coming from the core never contain cutsat decision variables + | .norm c | .divCoeffs c => c.collectDecVars + | .subst _ c₁ c₂ | .ofLeGe c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars + +partial def DvdCnstr.collectDecVars (c' : DvdCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do + match c'.h with + | .expr h => collectExpr h + | .norm c | .elim c | .divCoeffs c | .ofEq _ c => c.collectDecVars + | .solveCombine c₁ c₂ | .solveElim c₁ c₂ | .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars + +partial def LeCnstr.collectDecVars (c' : LeCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do + match c'.h with + | .expr h => collectExpr h + | .notExpr .. => return () -- This kind of proof is used for connecting with the `grind` core. + | .norm c | .divCoeffs c => c.collectDecVars + | .combine c₁ c₂ | .subst _ c₁ c₂ | .ofLeDiseq c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars + | .ofDiseqSplit _ _ _ decVars => + -- Recall that we cache the decision variables used in this kind of proof + for fvar in decVars do + markAsFound fvar + +partial def DiseqCnstr.collectDecVars (c' : DiseqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c'.id) do + match c'.h with + | .expr h => collectExpr h + | .core .. => return () -- Disequalities coming from the core never contain cutsat decision variables + | .norm c | .divCoeffs c | .neg c => c.collectDecVars + | .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars + +end + +def UnsatProof.collectDecVars (h : UnsatProof) : CollectDecVarsM Unit := do + match h with + | .le c | .dvd c | .eq c | .diseq c => c.collectDecVars + +abbrev CollectDecVarsM.run (x : CollectDecVarsM Unit) (decVars : FVarIdSet) : FVarIdSet := + let (_, s) := x decVars |>.run {} + s.found + end Lean.Meta.Grind.Arith.Cutsat diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Search.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Search.lean index ee705a4896..7cdb7f9283 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Search.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Search.lean @@ -8,6 +8,7 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util import Lean.Meta.Tactic.Grind.Arith.Cutsat.DvdCnstr import Lean.Meta.Tactic.Grind.Arith.Cutsat.LeCnstr +import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr import Lean.Meta.Tactic.Grind.Arith.Cutsat.SearchM import Lean.Meta.Tactic.Grind.Arith.Cutsat.Model @@ -265,8 +266,28 @@ def resolveCooperDvd (c₁ c₂ : LeCnstr) (c : DvdCnstr) : GoalM Unit := do def resolveCooperDiseq (c₁ : DiseqCnstr) (c₂ : LeCnstr) (_c? : Option DvdCnstr) : GoalM Unit := do throwError "Cooper-diseq NIY {← c₁.pp} {← c₂.pp}" -def resolveRatDiseq (c₁ : LeCnstr) (c : DiseqCnstr) : GoalM Unit := do - throwError "diseq NIY {← c₁.pp} {← c.pp}" +/-- +Given `c₁` of the form `-a₁*x + p₁ ≤ 0`, and `c` of the form `b*x + p ≠ 0`, +splits `c` and resolve with `c₁`. +Recall that a disequality +-/ +def resolveRatDiseq (c₁ : LeCnstr) (c : DiseqCnstr) : SearchM Unit := do + let c ← if c.p.leadCoeff < 0 then + mkDiseqCnstr (c.p.mul (-1)) (.neg c) + else + pure c + let fvarId ← if let some fvarId := (← get').diseqSplits.find? c.p then + trace[grind.debug.cutsat.diseq.split] "{← c.pp}, reusing {fvarId.name}" + pure fvarId + else + let fvarId ← mkCase (.diseq c) + trace[grind.debug.cutsat.diseq.split] "{← c.pp}, {fvarId.name}" + modify' fun s => { s with diseqSplits := s.diseqSplits.insert c.p fvarId } + pure fvarId + let p₂ := c.p.addConst 1 + let c₂ ← mkLeCnstr p₂ (.expr (mkFVar fvarId)) + let b ← resolveRealLowerUpperConflict c₁ c₂ + assert! b def processVar (x : Var) : SearchM Unit := do if (← eliminated x) then @@ -334,20 +355,47 @@ def processVar (x : Var) : SearchM Unit := do def hasAssignment : GoalM Bool := do return (← get').vars.size == (← get').assignment.size -private def isDone : GoalM Bool := do - if (← hasAssignment) then +private def findCase (decVars : FVarIdSet) : SearchM Case := do + repeat + let numCases := (← get).cases.size + assert! numCases > 0 + let case := (← get).cases[numCases-1]! + modify fun s => { s with cases := s.cases.pop } + if decVars.contains case.fvarId then + return case + -- Conflict does not depend on this case. + trace[grind.debug.cutsat.backtrack] "skipping {case.fvarId.name}" + unreachable! + +def resolveConflict (h : UnsatProof) : SearchM Bool := do + let decVars := h.collectDecVars.run (← get).decVars + if decVars.isEmpty then + closeGoal (← h.toExprProof) + return false + let c ← findCase decVars + modify' fun _ => c.saved + match c.kind with + | .diseq c₁ => + let decVars := decVars.erase c.fvarId |>.toArray + let p' := c₁.p.mul (-1) |>.addConst 1 + let c' ← mkLeCnstr p' (.ofDiseqSplit c₁ c.fvarId h decVars) + trace[grind.debug.cutsat.backtrack] "resolved diseq split: {← c'.pp}" + c'.assert return true - if (← inconsistent) then - return true - return false + | _ => throwError "NIY resolve conflict" /-- Search for an assignment/model for the linear constraints. -/ def searchAssigmentMain : SearchM Unit := do repeat - if (← isDone) then + if (← hasAssignment) then return () + if (← isInconsistent) then + -- `grind` state is inconsistent + return () + if let some c := (← get').conflict? then + unless (← resolveConflict c) do + return () let x : Var := (← get').assignment.size - -- TODO: resolve unsat conflicts processVar x def traceModel : GoalM Unit := do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/SearchM.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/SearchM.lean index d18cb7641a..b04e5b61b4 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/SearchM.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/SearchM.lean @@ -14,7 +14,7 @@ In principle, we only need to support two kinds of case split. - Cooper-Left, but we have 4 different variants of this one. -/ inductive CaseKind where - | diseq + | diseq (d : DiseqCnstr) | copperLeft | copperDvdLeft | cooperRight @@ -30,11 +30,6 @@ structure Case where -/ fvarId : FVarId /-- - Decision variable as a Lean type. We use it to construct - the actual proof term. - -/ - type : Expr - /-- Snapshot of the cutsat state for backtracking purposes. We do not use a trail stack. -/ @@ -75,4 +70,14 @@ def isApprox : SearchM Bool := def setImprecise : SearchM Unit := do modify fun s => { s with precise := false } +def mkCase (kind : CaseKind) : SearchM FVarId := do + let fvarId ← mkFreshFVarId + let saved ← get' + modify fun s => { s with + cases := s.cases.push { saved, fvarId, kind } + decVars := s.decVars.insert fvarId + } + modify' fun s => { s with caseSplits := true } + return fvarId + end Lean.Meta.Grind.Arith.Cutsat diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean index e318047eb3..dd86ffab16 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean @@ -15,6 +15,8 @@ namespace Lean.Meta.Grind.Arith.Cutsat export Int.Linear (Var Poly) export Std.Internal (Rat) +deriving instance Hashable for Poly + /-! This module implements a model-based decision procedure for linear integer arithmetic, inspired by Section 4 of "Cutting to the Chase: Solving Linear Integer Arithmetic". @@ -115,6 +117,7 @@ inductive LeCnstrProof where | combine (c₁ c₂ : LeCnstr) | subst (x : Var) (c₁ : EqCnstr) (c₂ : LeCnstr) | ofLeDiseq (c₁ : LeCnstr) (c₂ : DiseqCnstr) + | ofDiseqSplit (c₁ : DiseqCnstr) (decVar : FVarId) (h : UnsatProof) (decVars : Array FVarId) -- TODO: missing constructors /-- A disequality constraint and its justification/proof. -/ @@ -128,13 +131,9 @@ inductive DiseqCnstrProof where | core (p₁ p₂ : Poly) (h : Expr) | norm (c : DiseqCnstr) | divCoeffs (c : DiseqCnstr) + | neg (c : DiseqCnstr) | subst (x : Var) (c₁ : EqCnstr) (c₂ : DiseqCnstr) -end - -instance : Inhabited DvdCnstr where - default := { d := 0, p := .num 0, h := .expr default, id := 0 } - /-- A proof of `False`. Remark: We will later add support for a backtraking search inside of cutsat. @@ -145,6 +144,11 @@ inductive UnsatProof where | eq (c : EqCnstr) | diseq (c : DiseqCnstr) +end + +instance : Inhabited DvdCnstr where + default := { d := 0, p := .num 0, h := .expr default, id := 0 } + abbrev VarSet := RBTree Var compare /-- State of the cutsat procedure. -/ @@ -209,6 +213,12 @@ structure State where can convert `UnsatProof` into a Lean term and close the current `grind` goal. -/ conflict? : Option UnsatProof := none + /-- + Cache decision variables used when splitting on disequalities. + This is necessary because the same disequality may be in different conflicts. + -/ + diseqSplits : PHashMap Poly FVarId := {} + /- TODO: Model-based theory combination. -/ diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Util.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Util.lean index f20f0dfb49..ebb120beca 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Util.lean @@ -46,9 +46,8 @@ def get' : GoalM State := do /-- Returns `true` if the cutsat state is inconsistent. -/ def inconsistent : GoalM Bool := do - -- TODO: we will have a nested backtracking search in cutsat - -- and this function will have to be refined. - isInconsistent + if (← isInconsistent) then return true + return (← get').conflict?.isSome def getVars : GoalM (PArray Expr) := return (← get').vars diff --git a/tests/lean/run/grind_cutsat_diseq_3.lean b/tests/lean/run/grind_cutsat_diseq_3.lean new file mode 100644 index 0000000000..282fce0c91 --- /dev/null +++ b/tests/lean/run/grind_cutsat_diseq_3.lean @@ -0,0 +1,8 @@ +set_option grind.warning false +set_option grind.debug true +open Int.Linear + +theorem ex₁ (a b c : Int) : c ≥ 0 → b ≥ 0 → 1 ≤ a + c → a + b ≤ 1 → a ≠ 1 → c ≤ a → False := by + grind + +#print ex₁ diff --git a/tests/lean/run/grind_cutsat_le_2.lean b/tests/lean/run/grind_cutsat_le_2.lean index 65984087f9..c500666d38 100644 --- a/tests/lean/run/grind_cutsat_le_2.lean +++ b/tests/lean/run/grind_cutsat_le_2.lean @@ -1,5 +1,5 @@ set_option grind.warning false --- set_option grind.debug true +set_option grind.debug true open Int.Linear example (a b c d e : Int) :