From 45affb5e09c7fef88f2465ab96190668bf5d9d3b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 20 Aug 2025 19:59:52 -0700 Subject: [PATCH] fix: missing nonlinear `/` and `%` in `grind cutsat` (#10020) This PR fixes a missing case for PR #10010. --- src/Init/Data/Int/Linear.lean | 3 ++ .../Tactic/Grind/Arith/Cutsat/EqCnstr.lean | 22 +++++++++----- .../Meta/Tactic/Grind/Arith/Cutsat/Proof.lean | 26 ++++++++++++---- .../Meta/Tactic/Grind/Arith/Cutsat/Types.lean | 14 +++++++-- tests/lean/run/grind_linearize.lean | 30 +++++++++++++++++++ 5 files changed, 79 insertions(+), 16 deletions(-) diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index e16e6ae8a3..aca2c7fd40 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -2193,6 +2193,9 @@ theorem mul_eq_zero_right (a b : Int) (h : b = 0) : a*b = 0 := by simp [*] theorem div_eq (a b k : Int) (h : b = k) : a / b = a / k := by simp [*] theorem mod_eq (a b k : Int) (h : b = k) : a % b = a % k := by simp [*] +theorem div_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a/b') : a / b = k := by simp_all +theorem mod_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a%b') : a % b = k := by simp_all + end Int.Linear theorem Int.not_le_eq (a b : Int) : (¬a ≤ b) = (b + 1 ≤ a) := by diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean index 3ffbb3da72..aca0a45ea5 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean @@ -224,10 +224,13 @@ private def propagateNonlinearDiv (x : Var) : GoalM Bool := do let_expr HDiv.hDiv _ _ _ i a b := e | return false unless (← isInstHDivInt i) do return false let some (k, c) ← isExprEqConst? b | return false - let div' ← shareCommon (mkIntDiv a (mkIntLit k)) - internalize div' (← getGeneration e) - let y ← mkVar div' - let c' := { p := .add 1 x (.add (-1) y (.num 0)), h := .div k y c : EqCnstr } + let c' ← if let some a ← getIntValue? a then + pure { p := .add 1 x (.num (-(a/k))), h := .div k none c : EqCnstr } + else + let div' ← shareCommon (mkIntDiv a (mkIntLit k)) + internalize div' (← getGeneration e) + let y ← mkVar div' + pure { p := .add 1 x (.add (-1) y (.num 0)), h := .div k (some y) c : EqCnstr } c'.assert return true @@ -236,10 +239,13 @@ private def propagateNonlinearMod (x : Var) : GoalM Bool := do let_expr HMod.hMod _ _ _ i a b := e | return false unless (← isInstHModInt i) do return false let some (k, c) ← isExprEqConst? b | return false - let mod' ← shareCommon (mkIntMod a (mkIntLit k)) - internalize mod' (← getGeneration e) - let y ← mkVar mod' - let c' := { p := .add 1 x (.add (-1) y (.num 0)), h := .mod k y c : EqCnstr } + let c' ← if let some a ← getIntValue? a then + pure { p := .add 1 x (.num (-(a%k))), h := .mod k none c : EqCnstr } + else + let mod' ← shareCommon (mkIntMod a (mkIntLit k)) + internalize mod' (← getGeneration e) + let y ← mkVar mod' + pure { p := .add 1 x (.add (-1) y (.num 0)), h := .mod k (some y) c : EqCnstr } c'.assert return true diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index 63c791b462..91de39c1dd 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -281,20 +281,34 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do | .mul a? cs => let .add _ x _ := c'.p | c'.throwUnexpected mkMulEqProof x a? cs c' - | .div k y c => + | .div k y? c => let .add _ x _ := c'.p | c'.throwUnexpected let_expr HDiv.hDiv _ _ _ _ a b := (← getCurrVars)[x]! | c'.throwUnexpected let bVar ← getVarOf b let h := mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl bVar) (toExpr k) (← mkPolyDecl c.p) eagerReflBoolTrue (← c.toExprProof) - let h := mkApp4 (mkConst ``Int.Linear.div_eq) a b (toExpr k) h - return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) (← getContext) (← mkVarDecl x) (← mkVarDecl y) (← mkPolyDecl c'.p) eagerReflBoolTrue h - | .mod k y c => + if let some y := y? then + let h := mkApp4 (mkConst ``Int.Linear.div_eq) a b (toExpr k) h + return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) (← getContext) (← mkVarDecl x) (← mkVarDecl y) (← mkPolyDecl c'.p) eagerReflBoolTrue h + else + let b' := k + let some aVal ← getIntValue? a | unreachable! + let k := aVal / b' + let h := mkApp6 (mkConst ``Int.Linear.div_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue + return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h + | .mod k y? c => let .add _ x _ := c'.p | c'.throwUnexpected let_expr HMod.hMod _ _ _ _ a b := (← getCurrVars)[x]! | c'.throwUnexpected let bVar ← getVarOf b let h := mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl bVar) (toExpr k) (← mkPolyDecl c.p) eagerReflBoolTrue (← c.toExprProof) - let h := mkApp4 (mkConst ``Int.Linear.mod_eq) a b (toExpr k) h - return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) (← getContext) (← mkVarDecl x) (← mkVarDecl y) (← mkPolyDecl c'.p) eagerReflBoolTrue h + if let some y := y? then + let h := mkApp4 (mkConst ``Int.Linear.mod_eq) a b (toExpr k) h + return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) (← getContext) (← mkVarDecl x) (← mkVarDecl y) (← mkPolyDecl c'.p) eagerReflBoolTrue h + else + let b' := k + let some aVal ← getIntValue? a | unreachable! + let k := aVal % b' + let h := mkApp6 (mkConst ``Int.Linear.mod_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue + return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h partial def mkMulEqProof (x : Var) (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) (c' : EqCnstr) : ProofM Expr := do let h ← go (← getCurrVars)[x]! diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean index a335ba051e..410e0bca65 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean @@ -103,8 +103,18 @@ inductive EqCnstrProof where | defnCommRing (e : Expr) (p : Poly) (re : CommRing.RingExpr) (rp : CommRing.Poly) (p' : Poly) | defnNatCommRing (h : Expr) (x : Var) (e' : Int.Linear.Expr) (p : Poly) (re : CommRing.RingExpr) (rp : CommRing.Poly) (p' : Poly) | mul (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) - | div (k : Int) (y : Var) (c : EqCnstr) - | mod (k : Int) (y : Var) (c : EqCnstr) + | /-- + Linearization proof for `/` + - If `?y = some y`, then it is a proof for `a / b = y / k` where `c` is a proof that `b = k` + - If `?y = none`, then it is a proof for `a / b = a/k` where `c` is a proof that `b = k`. `a` is a numeral in this case. + -/ + div (k : Int) (y? : Option Var) (c : EqCnstr) + | /-- + Linearization proof for `%` + - If `?y = some y`, then it is a proof for `a % b = y%k` where `c` is a proof that `b = k` + - If `?y = none`, then it is a proof for `a % b = a%k` where `c` is a proof that `b = k`. `a` is a numeral in this case. + -/ + mod (k : Int) (y? : Option Var) (c : EqCnstr) /-- A divisibility constraint and its justification/proof. -/ structure DvdCnstr where diff --git a/tests/lean/run/grind_linearize.lean b/tests/lean/run/grind_linearize.lean index c71a3becbb..14e87814aa 100644 --- a/tests/lean/run/grind_linearize.lean +++ b/tests/lean/run/grind_linearize.lean @@ -63,3 +63,33 @@ example (a b c d : Nat) : b > 0 → d = 1 → b = d + 1 → a % b = 1 → a = 2 example (a b c d : Nat) : b > 1 → d = 1 → b ≤ d + 1 → a % b = 1 → a = 2 * c → False := by grind + +example (b : Int) : 4 % b = 1 → b = 2 → False := by + grind + +example (b : Int) : b = 2 → 4 % b = 1 → False := by + grind + +example (b : Nat) : 4 % b = 1 → b = 2 → False := by + grind + +example (b : Int) : 4 / b = 1 → b = 2 → False := by + grind + +example (b : Nat) : 4 / b = 1 → b = 2 → False := by + grind + +example (b : Int) : 4 % b = 1 → b ≤ 2 → 2 ≤ b → False := by + grind + +example (b : Int) : b ≤ 2 → 2 ≤ b → 4 % b = 1 → False := by + grind + +example (b : Nat) : 4 % b = 1 → b ≤ 2 → 2 ≤ b → False := by + grind + +example (b : Int) : 4 / b = 1 → b ≤ 2 → 2 ≤ b → False := by + grind + +example (b : Nat) : 4 / b = 1 → b ≤ 2 → 2 ≤ b → False := by + grind