diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index 047f470cd4..e16e6ae8a3 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -2167,6 +2167,13 @@ theorem of_var_eq_mul (ctx : Context) (x : Var) (k : Int) (y : Var) (p : Poly) : simp [of_var_eq_mul_cert]; intro _ h; subst p; simp [h] rw [Int.neg_mul, ← Int.sub_eq_add_neg, Int.sub_self] +@[expose] noncomputable def of_var_eq_var_cert (x : Var) (y : Var) (p : Poly) : Bool := + p.beq' (.add 1 x (.add (-1) y (.num 0))) + +theorem of_var_eq_var (ctx : Context) (x : Var) (y : Var) (p : Poly) : of_var_eq_var_cert x y p → x.denote ctx = y.denote ctx → p.denote' ctx = 0 := by + simp [of_var_eq_var_cert]; intro _ h; subst p; simp [h] + rw [← Int.sub_eq_add_neg, Int.sub_self] + @[expose] noncomputable def of_var_eq_cert (x : Var) (k : Int) (p : Poly) : Bool := p.beq' (.add 1 x (.num (-k))) @@ -2183,6 +2190,9 @@ theorem mul_eq_kxk (a b k₁ c k₂ k : Int) (h₁ : a = k₁*c) (h₂ : b = k theorem mul_eq_zero_left (a b : Int) (h : a = 0) : a*b = 0 := by simp [*] theorem mul_eq_zero_right (a b : Int) (h : b = 0) : a*b = 0 := by simp [*] +theorem div_eq (a b k : Int) (h : b = k) : a / b = a / k := by simp [*] +theorem mod_eq (a b k : Int) (h : b = k) : a % b = a % k := by simp [*] + end Int.Linear theorem Int.not_le_eq (a b : Int) : (¬a ≤ b) = (b + 1 ≤ a) := by diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean index 6780c4a040..3ffbb3da72 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean @@ -220,11 +220,27 @@ where goVar e private def propagateNonlinearDiv (x : Var) : GoalM Bool := do - trace[Meta.debug] "{← getVar x}" -- TODO + let e ← getVar x + let_expr HDiv.hDiv _ _ _ i a b := e | return false + unless (← isInstHDivInt i) do return false + let some (k, c) ← isExprEqConst? b | return false + let div' ← shareCommon (mkIntDiv a (mkIntLit k)) + internalize div' (← getGeneration e) + let y ← mkVar div' + let c' := { p := .add 1 x (.add (-1) y (.num 0)), h := .div k y c : EqCnstr } + c'.assert return true private def propagateNonlinearMod (x : Var) : GoalM Bool := do - trace[Meta.debug] "{← getVar x}" -- TODO + let e ← getVar x + let_expr HMod.hMod _ _ _ i a b := e | return false + unless (← isInstHModInt i) do return false + let some (k, c) ← isExprEqConst? b | return false + let mod' ← shareCommon (mkIntMod a (mkIntLit k)) + internalize mod' (← getGeneration e) + let y ← mkVar mod' + let c' := { p := .add 1 x (.add (-1) y (.num 0)), h := .mod k y c : EqCnstr } + c'.assert return true @[export lean_cutsat_propagate_nonlinear] @@ -540,15 +556,19 @@ private def expandDivMod (a : Expr) (b : Int) : GoalM Unit := do private def propagateDiv (e : Expr) : GoalM Unit := do let_expr HDiv.hDiv _ _ _ inst a b ← e | return () if (← isInstHDivInt inst) then - let some b ← getIntValue? b | return () - -- Remark: we currently do not consider the case where `b` is in the equivalence class of a numeral. - expandDivMod a b + if let some b ← getIntValue? b then + expandDivMod a b + else + discard <| mkVar e + private def propagateMod (e : Expr) : GoalM Unit := do let_expr HMod.hMod _ _ _ inst a b ← e | return () if (← isInstHModInt inst) then - let some b ← getIntValue? b | return () - expandDivMod a b + if let some b ← getIntValue? b then + expandDivMod a b + else + discard <| mkVar e private def propagateToInt (e : Expr) : GoalM Unit := do let_expr Grind.ToInt.toInt α _ _ a := e | return () diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index cbecefba7f..63c791b462 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -281,6 +281,20 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do | .mul a? cs => let .add _ x _ := c'.p | c'.throwUnexpected mkMulEqProof x a? cs c' + | .div k y c => + let .add _ x _ := c'.p | c'.throwUnexpected + let_expr HDiv.hDiv _ _ _ _ a b := (← getCurrVars)[x]! | c'.throwUnexpected + let bVar ← getVarOf b + let h := mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl bVar) (toExpr k) (← mkPolyDecl c.p) eagerReflBoolTrue (← c.toExprProof) + let h := mkApp4 (mkConst ``Int.Linear.div_eq) a b (toExpr k) h + return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) (← getContext) (← mkVarDecl x) (← mkVarDecl y) (← mkPolyDecl c'.p) eagerReflBoolTrue h + | .mod k y c => + let .add _ x _ := c'.p | c'.throwUnexpected + let_expr HMod.hMod _ _ _ _ a b := (← getCurrVars)[x]! | c'.throwUnexpected + let bVar ← getVarOf b + let h := mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl bVar) (toExpr k) (← mkPolyDecl c.p) eagerReflBoolTrue (← c.toExprProof) + let h := mkApp4 (mkConst ``Int.Linear.mod_eq) a b (toExpr k) h + return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) (← getContext) (← mkVarDecl x) (← mkVarDecl y) (← mkPolyDecl c'.p) eagerReflBoolTrue h partial def mkMulEqProof (x : Var) (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) (c' : EqCnstr) : ProofM Expr := do let h ← go (← getCurrVars)[x]! @@ -580,7 +594,7 @@ partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do u match c'.h with | .core0 .. | .core .. | .defn .. | .defnNat .. | .defnCommRing .. | .defnNatCommRing .. | .coreToInt .. => return () -- Equalities coming from the core never contain cutsat decision variables - | .commRingNorm c .. | .reorder c | .norm c | .divCoeffs c => c.collectDecVars + | .commRingNorm c .. | .reorder c | .norm c | .divCoeffs c | .div _ _ c | .mod _ _ c => c.collectDecVars | .subst _ c₁ c₂ | .ofLeGe c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars | .mul _ cs => cs.forM fun (_, _, c) => c.collectDecVars diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean index e2427a49d4..a335ba051e 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean @@ -103,6 +103,8 @@ inductive EqCnstrProof where | defnCommRing (e : Expr) (p : Poly) (re : CommRing.RingExpr) (rp : CommRing.Poly) (p' : Poly) | defnNatCommRing (h : Expr) (x : Var) (e' : Int.Linear.Expr) (p : Poly) (re : CommRing.RingExpr) (rp : CommRing.Poly) (p' : Poly) | mul (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) + | div (k : Int) (y : Var) (c : EqCnstr) + | mod (k : Int) (y : Var) (c : EqCnstr) /-- A divisibility constraint and its justification/proof. -/ structure DvdCnstr where diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean index 9a614a4b4b..22dc6ebfaf 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean @@ -39,8 +39,8 @@ private def registerNonlinearOcc (arg : Expr) (x : Var) : GoalM Unit := do private partial def registerNonlinearOccsAt (e : Expr) (x : Var) : GoalM Unit := do match_expr e with | HMul.hMul _ _ _ _ a b => go a; go b - | HDiv.hDiv _ _ _ _ a b => registerNonlinearOcc a x; registerNonlinearOcc b x - | HMod.hMod _ _ _ _ a b => registerNonlinearOcc a x; registerNonlinearOcc b x + | HDiv.hDiv _ _ _ _ _ b => registerNonlinearOcc b x + | HMod.hMod _ _ _ _ _ b => registerNonlinearOcc b x | _ => return () where go (e : Expr) : GoalM Unit := do diff --git a/tests/lean/run/grind_linearize.lean b/tests/lean/run/grind_linearize.lean index 88b6842557..c71a3becbb 100644 --- a/tests/lean/run/grind_linearize.lean +++ b/tests/lean/run/grind_linearize.lean @@ -30,3 +30,36 @@ example (a : Nat) (ha : a < 8) (b : Nat) (hb : b = 2) : a * b < 8 * b := by example (a : Nat) (ha : a < 8) (b c : Nat) : 2 ≤ b → c = 1 → b ≤ c + 1 → a * b < 8 * b := by grind -ring + +example (h : s = 4) : 4 < s - 1 + (s - 1) * (s - 1 - 1) / 2 := by + grind + +example (a b : Int) : a / b = 0 → b = 2 → a = 0 ∨ a = 1 := by + grind + +example (a b : Int) : b = 2 → a / b = 0 → a = 0 ∨ a = 1 := by + grind + +example (a b : Int) : b > 0 → b = 2 → a / b = 0 → a = 0 ∨ a = 1 := by + grind + +example (a b : Nat) : b > 0 → b = 2 → a / b = 0 → a = 0 ∨ a = 1 := by + grind + +example (a b c : Int) : a % b = 1 → b = 2 → a = 2 * c → False := by + grind + +example (a b c : Int) : b = 2 → a % b = 1 → a = 2 * c → False := by + grind + +example (a b c : Int) : b > 0 → b = 2 → a % b = 1 → a = 2 * c → False := by + grind + +example (a b c : Nat) : b > 0 → b = 2 → a % b = 1 → a = 2 * c → False := by + grind + +example (a b c d : Nat) : b > 0 → d = 1 → b = d + 1 → a % b = 1 → a = 2 * c → False := by + grind + +example (a b c d : Nat) : b > 1 → d = 1 → b ≤ d + 1 → a % b = 1 → a = 2 * c → False := by + grind