From f4afcfc923a85492ee763790a035fbc3c5406b2f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 14 Feb 2025 20:20:40 -0800 Subject: [PATCH] feat: divisibility constraint normalizer (#7092) This PR implements divisibility constraint normalization in `simp +arith`. --- src/Init/Data/Int/DivModLemmas.lean | 6 ++-- src/Init/Data/Int/Linear.lean | 18 ++++++---- src/Init/Data/Nat/Dvd.lean | 4 +-- src/Init/Omega/IntList.lean | 2 +- src/Lean/Expr.lean | 18 +++++++--- src/Lean/Meta/IntInstTesters.lean | 3 ++ src/Lean/Meta/Tactic/LinearArith/Basic.lean | 3 ++ .../Meta/Tactic/LinearArith/Int/Basic.lean | 33 +++++++++++++++++++ .../Meta/Tactic/LinearArith/Int/Simp.lean | 20 +++++++++++ src/Lean/Meta/Tactic/LinearArith/Simp.lean | 2 +- src/Lean/Meta/Tactic/Simp/Rewrite.lean | 5 +++ tests/lean/run/simp_int_arith.lean | 27 +++++++++++++++ 12 files changed, 122 insertions(+), 19 deletions(-) diff --git a/src/Init/Data/Int/DivModLemmas.lean b/src/Init/Data/Int/DivModLemmas.lean index 6c880e6ec8..d1003f1559 100644 --- a/src/Init/Data/Int/DivModLemmas.lean +++ b/src/Init/Data/Int/DivModLemmas.lean @@ -22,11 +22,11 @@ namespace Int protected theorem dvd_def (a b : Int) : (a ∣ b) = Exists (fun c => b = a * c) := rfl -protected theorem dvd_zero (n : Int) : n ∣ 0 := ⟨0, (Int.mul_zero _).symm⟩ +@[simp] protected theorem dvd_zero (n : Int) : n ∣ 0 := ⟨0, (Int.mul_zero _).symm⟩ -protected theorem dvd_refl (n : Int) : n ∣ n := ⟨1, (Int.mul_one _).symm⟩ +@[simp] protected theorem dvd_refl (n : Int) : n ∣ n := ⟨1, (Int.mul_one _).symm⟩ -protected theorem one_dvd (n : Int) : 1 ∣ n := ⟨n, (Int.one_mul n).symm⟩ +@[simp] protected theorem one_dvd (n : Int) : 1 ∣ n := ⟨n, (Int.one_mul n).symm⟩ protected theorem dvd_trans : ∀ {a b c : Int}, a ∣ b → b ∣ c → a ∣ c | _, _, _, ⟨d, rfl⟩, ⟨e, rfl⟩ => Exists.intro (d * e) (by rw [Int.mul_assoc]) diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index d0eb60268c..a42bf1034a 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -566,17 +566,20 @@ def Poly.mul (p : Poly) (k : Int) : Poly := rw [Int.mul_assoc] structure DvdPolyCnstr where - a : Int + k : Int p : Poly def DvdPolyCnstr.denote (ctx : Context) (c : DvdPolyCnstr) : Prop := - c.a ∣ c.p.denote ctx + c.k ∣ c.p.denote ctx def DvdPolyCnstr.isUnsat (c : DvdPolyCnstr) : Bool := - c.p.getConst % c.p.gcdCoeffs c.a != 0 + c.p.getConst % c.p.gcdCoeffs c.k != 0 def DvdPolyCnstr.isEqv (c₁ c₂ : DvdPolyCnstr) (k : Int) : Bool := - k != 0 && c₁.a == k*c₂.a && c₁.p == c₂.p.mul k + k != 0 && c₁.k == k*c₂.k && c₁.p == c₂.p.mul k + +def DvdPolyCnstr.div (k' : Int) : DvdPolyCnstr → DvdPolyCnstr + | { k, p } => { k := k / k', p := p.div k' } private theorem not_dvd_of_not_mod_zero {a b : Int} (h : ¬ b % a = 0) : ¬ a ∣ b := by intro h; have := Int.emod_eq_zero_of_dvd h; contradiction @@ -611,14 +614,15 @@ def DvdPolyCnstr.eq_false_of_isUnsat (ctx : Context) (c : DvdPolyCnstr) : c.isUn simp [denote, *] structure DvdCnstr where - a : Int + k : Int e : Expr + deriving BEq def DvdCnstr.denote (ctx : Context) (c : DvdCnstr) : Prop := - c.a ∣ c.e.denote ctx + c.k ∣ c.e.denote ctx def DvdCnstr.toPoly (c : DvdCnstr) : DvdPolyCnstr := - { a := c.a, p := c.e.toPoly } + { k := c.k, p := c.e.toPoly } @[simp] theorem DvdCnstr.denote_toPoly_eq (ctx : Context) (c : DvdCnstr) : c.denote ctx = c.toPoly.denote ctx := by simp [toPoly, denote, DvdPolyCnstr.denote] diff --git a/src/Init/Data/Nat/Dvd.lean b/src/Init/Data/Nat/Dvd.lean index c29559b589..1c09b35766 100644 --- a/src/Init/Data/Nat/Dvd.lean +++ b/src/Init/Data/Nat/Dvd.lean @@ -9,9 +9,9 @@ import Init.Meta namespace Nat -protected theorem dvd_refl (a : Nat) : a ∣ a := ⟨1, by simp⟩ +@[simp] protected theorem dvd_refl (a : Nat) : a ∣ a := ⟨1, by simp⟩ -protected theorem dvd_zero (a : Nat) : a ∣ 0 := ⟨0, by simp⟩ +@[simp] protected theorem dvd_zero (a : Nat) : a ∣ 0 := ⟨0, by simp⟩ protected theorem dvd_mul_left (a b : Nat) : a ∣ b * a := ⟨b, Nat.mul_comm b a⟩ protected theorem dvd_mul_right (a b : Nat) : a ∣ a * b := ⟨b, rfl⟩ diff --git a/src/Init/Omega/IntList.lean b/src/Init/Omega/IntList.lean index 8726d35632..4501c6021b 100644 --- a/src/Init/Omega/IntList.lean +++ b/src/Init/Omega/IntList.lean @@ -303,7 +303,7 @@ theorem dvd_gcd (xs : IntList) (c : Nat) (w : ∀ {a : Int}, a ∈ xs → (c : I c ∣ xs.gcd := by simp only [Int.ofNat_dvd_left] at w induction xs with - | nil => have := Nat.dvd_zero c; simp at this; exact this + | nil => have := Nat.dvd_zero c; simp | cons x xs ih => simp apply Nat.dvd_gcd diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 93117c4139..57b3599e9b 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -2245,20 +2245,28 @@ def mkIntMul (a b : Expr) : Expr := private def intLEPred : Expr := mkApp2 (mkConst ``LE.le [0]) Int.mkType Int.mkInstLE -/-- Given `a b : Int`, return `a ≤ b` -/ +/-- Given `a b : Int`, returns `a ≤ b` -/ def mkIntLE (a b : Expr) : Expr := mkApp2 intLEPred a b private def intEqPred : Expr := mkApp (mkConst ``Eq [1]) Int.mkType -/-- Given `a b : Int`, return `a = b` -/ +/-- Given `a b : Int`, returns `a = b` -/ def mkIntEq (a b : Expr) : Expr := mkApp2 intEqPred a b -def mkIntLit (n : Nat) : Expr := - let r := mkRawNatLit n - mkApp3 (mkConst ``OfNat.ofNat [levelZero]) Int.mkType r (mkApp (mkConst ``instOfNat) r) +/-- Given `a b : Int`, returns `a ∣ b` -/ +def mkIntDvd (a b : Expr) : Expr := + mkApp4 (mkConst ``Dvd.dvd [0]) Int.mkType (mkConst ``Int.instDvd) a b + +def mkIntLit (n : Int) : Expr := + let r := mkRawNatLit n.natAbs + let r := mkApp3 (mkConst ``OfNat.ofNat [levelZero]) Int.mkType r (mkApp (mkConst ``instOfNat) r) + if n < 0 then + mkIntNeg r + else + r def reflBoolTrue : Expr := mkApp2 (mkConst ``Eq.refl [levelOne]) (mkConst ``Bool) (mkConst ``Bool.true) diff --git a/src/Lean/Meta/IntInstTesters.lean b/src/Lean/Meta/IntInstTesters.lean index 4dfe86c12d..ef54eaa316 100644 --- a/src/Lean/Meta/IntInstTesters.lean +++ b/src/Lean/Meta/IntInstTesters.lean @@ -32,6 +32,9 @@ def isInstDivInt (e : Expr) : MetaM Bool := do def isInstModInt (e : Expr) : MetaM Bool := do let_expr Int.instMod ← e | return false return true +def isInstDvdInt (e : Expr) : MetaM Bool := do + let_expr Int.instDvd ← e | return false + return true def isInstHAddInt (e : Expr) : MetaM Bool := do let_expr instHAdd _ i ← e | return false isInstAddInt i diff --git a/src/Lean/Meta/Tactic/LinearArith/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Basic.lean index 3be0041b5c..80bd43146f 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Basic.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Basic.lean @@ -57,4 +57,7 @@ partial def isLinearCnstr (e : Expr) : Bool := else false +def isDvdCnstr (e : Expr) : Bool := + e.isAppOfArity ``Dvd.dvd 4 + end Lean.Meta.Linear diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean index c7a7585075..5b2a2475f4 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean @@ -50,6 +50,12 @@ def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr → ExprCnstr | .eq a b => .eq (a.applyPerm perm) (b.applyPerm perm) | .le a b => .le (a.applyPerm perm) (b.applyPerm perm) +def DvdCnstr.applyPerm (perm : Lean.Perm) : DvdCnstr → DvdCnstr + | { k, e } => { k, e := e.applyPerm perm } + +def DvdPolyCnstr.toDvdCnstr : DvdPolyCnstr → DvdCnstr + | { k, p } => { k, e := p.toExpr } + end Int.Linear namespace Lean.Meta.Linear.Int @@ -62,6 +68,7 @@ deriving instance Repr for Int.Linear.PolyCnstr abbrev LinearExpr := Int.Linear.Expr abbrev LinearCnstr := Int.Linear.ExprCnstr abbrev PolyExpr := Int.Linear.Poly +abbrev DvdCnstr := Int.Linear.DvdCnstr def LinearExpr.toExpr (e : LinearExpr) : Expr := open Int.Linear.Expr in @@ -88,6 +95,13 @@ instance : ToExpr LinearCnstr where toExpr a := a.toExpr toTypeExpr := mkConst ``Int.Linear.ExprCnstr +protected def DvdCnstr.toExpr (c : DvdCnstr) : Expr := + mkApp2 (mkConst ``Int.Linear.DvdCnstr.mk) (toExpr c.k) (toExpr c.e) + +instance : ToExpr DvdCnstr where + toExpr a := a.toExpr + toTypeExpr := mkConst ``Int.Linear.DvdCnstr + open Int.Linear.Expr in def LinearExpr.toArith (ctx : Array Expr) (e : LinearExpr) : MetaM Expr := do match e with @@ -104,6 +118,9 @@ def LinearCnstr.toArith (ctx : Array Expr) (c : LinearCnstr) : MetaM Expr := do | .eq e₁ e₂ => return mkIntEq (← LinearExpr.toArith ctx e₁) (← LinearExpr.toArith ctx e₂) | .le e₁ e₂ => return mkIntLE (← LinearExpr.toArith ctx e₁) (← LinearExpr.toArith ctx e₂) +def DvdCnstr.toArith (ctx : Array Expr) (c : DvdCnstr) : MetaM Expr := do + return mkIntDvd (mkIntLit c.k) (← LinearExpr.toArith ctx c.e) + namespace ToLinear structure State where @@ -200,6 +217,12 @@ partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := OptionT.run do return .le (.add (← toLinearExpr b) (.num 1)) (← toLinearExpr a) | _ => failure +partial def toDvdCnstr? (e : Expr) : M (Option DvdCnstr) := OptionT.run do + let_expr Dvd.dvd _ inst k b ← e | failure + guard (← isInstDvdInt inst) + let some k ← getIntValue? k | failure + return { k, e := (← toLinearExpr b) } + def run (x : M α) : MetaM (α × Array Expr) := do let (a, s) ← x.run {} return (a, s.vars) @@ -225,6 +248,16 @@ def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do let c := c.applyPerm perm return some (c, atoms) +def toDvdCnstr? (e : Expr) : MetaM (Option (DvdCnstr × Array Expr)) := do + let (some c, atoms) ← ToLinear.run (ToLinear.toDvdCnstr? e) + | return none + if atoms.size <= 1 then + return some (c, atoms) + else + let (atoms, perm) := sortExprs atoms + let c := c.applyPerm perm + return some (c, atoms) + def toContextExpr (ctx : Array Expr) : Expr := if h : 0 < ctx.size then RArray.toExpr (mkConst ``Int) id (RArray.ofArray ctx h) diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean index 830ee70ad8..0e21c6008d 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean @@ -126,6 +126,26 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do else simpCnstrPos? e +def simpDvdCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do + let some (c, atoms) ← toDvdCnstr? e | return none + if c.k == 0 then return none + withAbstractAtoms atoms ``Int fun atoms => do + let lhs ← c.toArith atoms + let c' := c.toPoly + let k := c'.p.gcdCoeffs c'.k + if c'.p.getConst % k == 0 then + let c' := c'.div k + let c' : DvdCnstr := c'.toDvdCnstr + if c == c' then + return none + let r ← c'.toArith atoms + let h := mkApp5 (mkConst ``Int.Linear.DvdCnstr.eq_of_isEqv) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr k) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) + else + let r := mkConst ``False + let h := mkApp3 (mkConst ``Int.Linear.DvdCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue + return some (r, ← mkExpectedTypeHint h (← mkEq lhs r)) + def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do let (e, atoms) ← toLinearExpr e let p := e.toPoly diff --git a/src/Lean/Meta/Tactic/LinearArith/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Simp.lean index cb548b2187..2e393408f5 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Simp.lean @@ -13,7 +13,7 @@ namespace Lean.Meta.Linear def parentIsTarget (parent? : Option Expr) : Bool := match parent? with | none => false - | some parent => isLinearTerm parent || isLinearCnstr parent + | some parent => isLinearTerm parent || isLinearCnstr parent || isDvdCnstr parent def simp? (e : Expr) (parent? : Option Expr) : MetaM (Option (Expr × Expr)) := do -- TODO: add support for `Int` and arbitrary ordered comm rings diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index af4007ce2a..d795311ba0 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -300,6 +300,11 @@ def simpArith (e : Expr) : SimpM Step := do return .visit { expr := e', proof? := h } else return .continue + else if Linear.isDvdCnstr e then + if let some (e', h) ← Linear.Int.simpDvdCnstr? e then + return .visit { expr := e', proof? := h } + else + return .continue else return .continue diff --git a/tests/lean/run/simp_int_arith.lean b/tests/lean/run/simp_int_arith.lean index 8409f382ea..fbfbdd647c 100644 --- a/tests/lean/run/simp_int_arith.lean +++ b/tests/lean/run/simp_int_arith.lean @@ -265,3 +265,30 @@ example (x y : Int) : (2*x + y + y ≤ 3) ↔ (y + x ≤ 1) := by example (f : Int → Int) (x y : Int) : f (2*x + y) = f (y + x + x) := by simp +arith + +example (a b : Int) : ¬ 2 ∣ 2*a + 4*b + 1 := by + simp +arith + +example (a b : Int) : ¬ 2 ∣ a + 3*b + 1 + b + a := by + simp +arith + +example (a b : Int) : ¬ 2 ∣ a + 3*b + 1 + b + 5*a := by + simp +arith + +example (a b : Int) : 2 ∣ 4*a + 6*b + 8 := by + simp +arith + +example (a b : Int) : 2 ∣ 2*(a + a) + (3+3)*(b + b) + 8 := by + simp +arith + +example (a : Int) : 16 ∣ 4*a + 32 ↔ 4 ∣ a + 8 := by + simp +arith + +example (a : Int) : 3 ∣ a + a + 1 + a + 1 + a ↔ 3 ∣ 4*a + 2 := by + simp +arith + +example (a : Int) : 2+1 ∣ a + a + 1 - a + 1 + a ↔ 3 ∣ 2*a + 2 := by + simp +arith + +example (a b : Int) : 6 ∣ a + 21 - a + 3*a + 6*b + 12 ↔ 2 ∣ a + 2*b + 11 := by + simp +arith