From ae728d84f007f798663eee7976f32b88316a8aeb Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 4 Aug 2025 14:27:11 +0200 Subject: [PATCH] perf: proof terms for `grind ring` and `grind cutsat` (#9710) This PR improves some of the proof terms produced by `grind ring` and `grind cutsat`. --- src/Init/Data/Int/Basic.lean | 14 ++++++++++ src/Init/Data/Int/LemmasAux.lean | 12 +++++++++ src/Init/Data/Int/Linear.lean | 32 +++++++++++------------ src/Init/Grind/Ring/Poly.lean | 45 +++++++++++++++++--------------- 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/src/Init/Data/Int/Basic.lean b/src/Init/Data/Int/Basic.lean index cab7d0f497..dc3e241270 100644 --- a/src/Init/Data/Int/Basic.lean +++ b/src/Init/Data/Int/Basic.lean @@ -410,6 +410,20 @@ instance : Max Int := maxOfLe (fun a => Int.rec (fun b => Nat.beq a b) (fun _ => false) b) (fun a => Int.rec (fun _ => false) (fun b => Nat.beq a b) b) a +/-- `x ≤ y` for kernel reduction. -/ +@[expose] protected noncomputable def ble' (a b : Int) : Bool := + Int.rec + (fun a => Int.rec (fun b => Nat.ble a b) (fun _ => false) b) + (fun a => Int.rec (fun _ => true) (fun b => Nat.ble b a) b) + a + +/-- `x < y` for kernel reduction. -/ +@[expose] protected noncomputable def blt' (a b : Int) : Bool := + Int.rec + (fun a => Int.rec (fun b => Nat.blt a b) (fun _ => false) b) + (fun a => Int.rec (fun _ => true) (fun b => Nat.blt b a) b) + a + end Int /-- diff --git a/src/Init/Data/Int/LemmasAux.lean b/src/Init/Data/Int/LemmasAux.lean index 683ff45a3e..109941fa1c 100644 --- a/src/Init/Data/Int/LemmasAux.lean +++ b/src/Init/Data/Int/LemmasAux.lean @@ -69,6 +69,18 @@ theorem natCast_succ_pos (n : Nat) : 0 < (n.succ : Int) := natCast_pos.2 n.succ_ @[simp, norm_cast] theorem cast_id {n : Int} : Int.cast n = n := rfl +@[simp] theorem ble'_eq_true (a b : Int) : (Int.ble' a b = true) = (a ≤ b) := by + cases a <;> cases b <;> simp [Int.ble'] <;> omega + +@[simp] theorem blt'_eq_true (a b : Int) : (Int.blt' a b = true) = (a < b) := by + cases a <;> cases b <;> simp [Int.blt'] <;> omega + +@[simp] theorem ble'_eq_false (a b : Int) : (Int.ble' a b = false) = ¬(a ≤ b) := by + simp [← Bool.not_eq_true] + +@[simp] theorem blt'_eq_false (a b : Int) : (Int.blt' a b = false) = ¬ (a < b) := by + simp [← Bool.not_eq_true] + /-! ### toNat -/ @[simp] theorem toNat_sub' (a : Int) (b : Nat) : (a - b).toNat = a.toNat - b := by diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index 0b919e50af..04ffa1bbf4 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -586,7 +586,7 @@ theorem norm_eq_coeff' (ctx : Context) (p p' : Poly) (k : Int) : p = p'.mul k @[expose] noncomputable def norm_eq_coeff_cert (lhs rhs : Expr) (p : Poly) (k : Int) : Bool := - (lhs.sub rhs).norm.beq' (p.mul_k k) |>.and' (k > 0) + (lhs.sub rhs).norm.beq' (p.mul_k k) |>.and' (Int.blt' 0 k) theorem norm_eq_coeff (ctx : Context) (lhs rhs : Expr) (p : Poly) (k : Int) : norm_eq_coeff_cert lhs rhs p k → (lhs.denote ctx = rhs.denote ctx) = (p.denote' ctx = 0) := by @@ -653,7 +653,7 @@ private theorem eq_of_norm_eq_of_divCoeffs {ctx : Context} {p₁ p₂ : Poly} {k @[expose] noncomputable def norm_le_coeff_tight_cert (lhs rhs : Expr) (p : Poly) (k : Int) : Bool := let p' := lhs.sub rhs |>.norm - (k > 0 : Bool) |>.and' (p'.divCoeffs k |>.and' (p.beq' (p'.div k))) + (Int.blt' 0 k) |>.and' (p'.divCoeffs k |>.and' (p.beq' (p'.div k))) theorem norm_le_coeff_tight (ctx : Context) (lhs rhs : Expr) (p : Poly) (k : Int) : norm_le_coeff_tight_cert lhs rhs p k → (lhs.denote ctx ≤ rhs.denote ctx) = (p.denote' ctx ≤ 0) := by @@ -765,7 +765,7 @@ private theorem poly_eq_zero_eq_false (ctx : Context) {p : Poly} {k : Int} : p.d @[expose] noncomputable def unsatEqDivCoeffCert (lhs rhs : Expr) (k : Int) : Bool := let p := (lhs.sub rhs).norm - p.divCoeffs k |>.and' ((k > 0 : Bool) |>.and' (cmod p.getConst k < 0)) + p.divCoeffs k |>.and' (Int.blt' 0 k |>.and' (cmod p.getConst k < 0)) theorem eq_eq_false_of_divCoeff (ctx : Context) (lhs rhs : Expr) (k : Int) : unsatEqDivCoeffCert lhs rhs k → (lhs.denote ctx = rhs.denote ctx) = False := by simp [unsatEqDivCoeffCert] @@ -997,7 +997,7 @@ theorem le_norm (ctx : Context) (p₁ p₂ : Poly) (h : p₁.norm.beq' p₂) : p @[expose] noncomputable def le_coeff_cert (p₁ p₂ : Poly) (k : Int) : Bool := - (k > 0 : Bool).and' (p₁.divCoeffs k |>.and' (p₂.beq' (p₁.div k))) + Int.blt' 0 k |>.and' (p₁.divCoeffs k |>.and' (p₂.beq' (p₁.div k))) theorem le_coeff (ctx : Context) (p₁ p₂ : Poly) (k : Int) : le_coeff_cert p₁ p₂ k → p₁.denote' ctx ≤ 0 → p₂.denote' ctx ≤ 0 := by simp [le_coeff_cert] @@ -1042,12 +1042,12 @@ noncomputable def le_combine_coeff_cert (p₁ p₂ p₃ : Poly) (k : Int) : Bool let a₁ := p₁.leadCoeff.natAbs let a₂ := p₂.leadCoeff.natAbs let p := p₁.combine_mul_k a₂ a₁ p₂ - (k > 0 : Bool).and' (p.divCoeffs k |>.and' (p₃.beq' (p.div k))) + Int.blt' 0 k |>.and' (p.divCoeffs k |>.and' (p₃.beq' (p.div k))) theorem le_combine_coeff (ctx : Context) (p₁ p₂ p₃ : Poly) (k : Int) : le_combine_coeff_cert p₁ p₂ p₃ k → p₁.denote' ctx ≤ 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by simp only [le_combine_coeff_cert, Bool.and'_eq_and, - Poly.beq'_eq, gt_iff_lt, Bool.and_eq_true, decide_eq_true_eq, and_imp] + Poly.beq'_eq, Bool.and_eq_true, and_imp] let a₁ := p₁.leadCoeff.natAbs let a₂ := p₂.leadCoeff.natAbs generalize h : (p₁.combine_mul_k a₂ a₁ p₂) = p @@ -1056,7 +1056,7 @@ theorem le_combine_coeff (ctx : Context) (p₁ p₂ p₃ : Poly) (k : Int) simp only [le_combine_cert, Poly.beq'_eq] at this have aux₁ := this h.symm h₄ h₅ have := le_coeff ctx p p₃ k - simp only [le_coeff_cert, Bool.and'_eq_and, Poly.beq'_eq, gt_iff_lt, Bool.and_eq_true, decide_eq_true_eq, and_imp] at this + simp only [le_coeff_cert, Bool.and'_eq_and, Poly.beq'_eq, Bool.and_eq_true, and_imp] at this exact this h₁ h₂ h₃ aux₁ theorem le_unsat (ctx : Context) (p : Poly) : p.isUnsatLe → p.denote' ctx ≤ 0 → False := by @@ -1081,7 +1081,7 @@ theorem eq_unsat (ctx : Context) (p : Poly) : p.isUnsatEq → p.denote' ctx = 0 @[expose] noncomputable def eq_unsat_coeff_cert (p : Poly) (k : Int) : Bool := - p.divCoeffs k |>.and' ((k > 0 : Bool).and' (cmod p.getConst k < 0)) + p.divCoeffs k |>.and' (Int.blt' 0 k |>.and' (cmod p.getConst k < 0)) theorem eq_unsat_coeff (ctx : Context) (p : Poly) (k : Int) : eq_unsat_coeff_cert p k → p.denote' ctx = 0 → False := by simp [eq_unsat_coeff_cert] @@ -1208,7 +1208,7 @@ theorem eq_eq_subst' (ctx : Context) (a b : Int) (p₁ : Poly) (p₂ : Poly) (p noncomputable def eq_le_subst_nonneg_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool := let a := p₁.coeff x let b := p₂.coeff x - (a ≥ 0 : Bool).and' (p₃.beq' (p₂.combine_mul_k a (-b) p₁)) + Int.ble' 0 a |>.and' (p₃.beq' (p₂.combine_mul_k a (-b) p₁)) theorem eq_le_subst_nonneg (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : eq_le_subst_nonneg_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by @@ -1224,7 +1224,7 @@ theorem eq_le_subst_nonneg (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) noncomputable def eq_le_subst_nonpos_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool := let a := p₁.coeff x let b := p₂.coeff x - (a ≤ 0 : Bool).and' (p₃.beq' (p₁.combine_mul_k b (-a) p₂)) + Int.ble' a 0 |>.and' (p₃.beq' (p₁.combine_mul_k b (-a) p₂)) theorem eq_le_subst_nonpos (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : eq_le_subst_nonpos_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by @@ -1933,12 +1933,12 @@ noncomputable def dvd_le_tight_cert (d : Int) (p₁ p₂ p₃ : Poly) : Bool := let b₁ := p₁.getConst let b₂ := p₂.getConst let p := p₁.addConst_k (-b₁) - (d > 0 : Bool) |>.and' (p₂.beq' (p.addConst_k b₂) |>.and' (p₃.beq' (p.addConst_k (b₁ - d*((b₁ - b₂)/d))))) + Int.blt' 0 d |>.and' (p₂.beq' (p.addConst_k b₂) |>.and' (p₃.beq' (p.addConst_k (b₁ - d*((b₁ - b₂)/d))))) theorem dvd_le_tight (ctx : Context) (d : Int) (p₁ p₂ p₃ : Poly) : dvd_le_tight_cert d p₁ p₂ p₃ → d ∣ p₁.denote' ctx → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by - simp only [dvd_le_tight_cert, gt_iff_lt, Bool.and'_eq_and, Poly.beq'_eq, Bool.and_eq_true, - Poly.addConst_k_eq_addConst, decide_eq_true_eq, and_imp] + simp only [dvd_le_tight_cert, Bool.and'_eq_and, Poly.beq'_eq, Bool.and_eq_true, + Poly.addConst_k_eq_addConst, Int.blt'_eq_true, and_imp] generalize p₂.getConst = b₂ intro hd _ _; subst p₂ p₃ have := eq_neg_addConst_add ctx p₁ @@ -1957,7 +1957,7 @@ noncomputable def dvd_neg_le_tight_cert (d : Int) (p₁ p₂ p₃ : Poly) : Bool let p := p₁.addConst_k (-b₁) let b₁ := -b₁ let p := p.mul_k (-1) - (d > 0 : Bool) |>.and' (p₂.beq' (p.addConst_k b₂) |>.and' (p₃.beq' (p.addConst_k (b₁ - d*((b₁ - b₂)/d))))) + Int.blt' 0 d |>.and' (p₂.beq' (p.addConst_k b₂) |>.and' (p₃.beq' (p.addConst_k (b₁ - d*((b₁ - b₂)/d))))) theorem Poly.mul_minus_one_getConst_eq (p : Poly) : (p.mul (-1)).getConst = -p.getConst := by simp [Poly.mul] @@ -1965,8 +1965,8 @@ theorem Poly.mul_minus_one_getConst_eq (p : Poly) : (p.mul (-1)).getConst = -p.g theorem dvd_neg_le_tight (ctx : Context) (d : Int) (p₁ p₂ p₃ : Poly) : dvd_neg_le_tight_cert d p₁ p₂ p₃ → d ∣ p₁.denote' ctx → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by - simp only [dvd_neg_le_tight_cert, gt_iff_lt, Poly.beq'_eq, Bool.and'_eq_and, Bool.and_eq_true, - decide_eq_true_eq, and_imp] + simp only [dvd_neg_le_tight_cert, Poly.beq'_eq, Bool.and'_eq_and, Bool.and_eq_true, + Int.blt'_eq_true, and_imp] generalize p₂.getConst = b₂ intro hd _ _; subst p₂ p₃ simp only [Poly.denote'_eq_denote, Int.reduceNeg, Poly.addConst_k_eq_addConst, Poly.denote_addConst, Poly.denote_mul, Poly.mul_k_eq_mul, diff --git a/src/Init/Grind/Ring/Poly.lean b/src/Init/Grind/Ring/Poly.lean index 6a1976335f..15d8415803 100644 --- a/src/Init/Grind/Ring/Poly.lean +++ b/src/Init/Grind/Ring/Poly.lean @@ -7,6 +7,7 @@ module prelude public import Init.Data.Nat.Lemmas +public import Init.Data.Int.LemmasAux public import Init.Data.Hashable public import all Init.Data.Ord public import Init.Data.RArray @@ -41,23 +42,25 @@ def Var.denote {α} (ctx : Context α) (v : Var) : α := ctx.get v @[expose] -def denoteInt {α} [Ring α] (k : Int) : α := - bif k < 0 then - - OfNat.ofNat (α := α) k.natAbs - else - OfNat.ofNat (α := α) k.natAbs +noncomputable def denoteInt {α} [Ring α] (k : Int) : α := + Bool.rec + (OfNat.ofNat (α := α) k.natAbs) + (- OfNat.ofNat (α := α) k.natAbs) + (Int.blt' k 0) @[expose] -def Expr.denote {α} [Ring α] (ctx : Context α) : Expr → α - | .add a b => denote ctx a + denote ctx b - | .sub a b => denote ctx a - denote ctx b - | .mul a b => denote ctx a * denote ctx b - | .neg a => -denote ctx a - | .num k => denoteInt k - | .natCast k => NatCast.natCast (R := α) k - | .intCast k => IntCast.intCast (R := α) k - | .var v => v.denote ctx - | .pow a k => denote ctx a ^ k +noncomputable def Expr.denote {α} [Ring α] (ctx : Context α) (e : Expr) : α := + Expr.rec + (fun k => denoteInt k) + (fun k => NatCast.natCast (R := α) k) + (fun k => IntCast.intCast (R := α) k) + (fun x => x.denote ctx) + (fun _ ih => - ih) + (fun _ _ ih₁ ih₂ => ih₁ + ih₂) + (fun _ _ ih₁ ih₂ => ih₁ - ih₂) + (fun _ _ ih₁ ih₂ => ih₁ * ih₂) + (fun _ k ih => ih ^ k) + e structure Power where x : Var @@ -797,7 +800,7 @@ q₁*(lhs₁ - rhs₁) + ... + qₙ*(lhsₙ - rhsₙ) ``` -/ @[expose] -def NullCert.denote {α} [CommRing α] (ctx : Context α) : NullCert → α +noncomputable def NullCert.denote {α} [CommRing α] (ctx : Context α) : NullCert → α | .empty => 0 | .add q lhs rhs nc => (q.denote ctx)*(lhs.denote ctx - rhs.denote ctx) + nc.denote ctx @@ -840,9 +843,9 @@ open Ring hiding sub_eq_add_neg open CommSemiring theorem denoteInt_eq {α} [CommRing α] (k : Int) : denoteInt (α := α) k = k := by - simp [denoteInt, cond_eq_if] <;> split - next h => rw [ofNat_eq_natCast, ← intCast_natCast, ← intCast_neg, ← Int.eq_neg_natAbs_of_nonpos (Int.le_of_lt h)] - next h => rw [ofNat_eq_natCast, ← intCast_natCast, ← Int.eq_natAbs_of_nonneg (Int.le_of_not_gt h)] + simp [denoteInt] <;> cases h : k.blt' 0 <;> simp <;> simp at h + next h => rw [ofNat_eq_natCast, ← intCast_natCast, ← Int.eq_natAbs_of_nonneg h] + next h => rw [ofNat_eq_natCast, ← intCast_natCast, ← Ring.intCast_neg, ← Int.eq_neg_natAbs_of_nonpos (Int.le_of_lt h)] theorem Power.denote_eq {α} [Semiring α] (ctx : Context α) (p : Power) : p.denote ctx = p.x.denote ctx ^ p.k := by @@ -1596,14 +1599,14 @@ theorem not_lt_norm {α} [CommRing α] [LinearOrder α] [OrderedRing α] (ctx : theorem not_le_norm' {α} [CommRing α] [Preorder α] [OrderedRing α] (ctx : Context α) (lhs rhs : Expr) (p : Poly) : core_cert lhs rhs p → ¬ lhs.denote ctx ≤ rhs.denote ctx → ¬ p.denoteAsIntModule ctx ≤ 0 := by simp [core_cert, Poly.denoteAsIntModule_eq_denote]; intro _ h₁; subst p; simp [Expr.denote_toPoly, Expr.denote]; intro h - replace h := add_le_right (rhs.denote ctx) h + replace h : rhs.denote ctx + (lhs.denote ctx - rhs.denote ctx) ≤ _ := add_le_right (rhs.denote ctx) h rw [sub_eq_add_neg, add_left_comm, ← sub_eq_add_neg, sub_self] at h; simp [add_zero] at h contradiction theorem not_lt_norm' {α} [CommRing α] [Preorder α] [OrderedRing α] (ctx : Context α) (lhs rhs : Expr) (p : Poly) : core_cert lhs rhs p → ¬ lhs.denote ctx < rhs.denote ctx → ¬ p.denoteAsIntModule ctx < 0 := by simp [core_cert, Poly.denoteAsIntModule_eq_denote]; intro _ h₁; subst p; simp [Expr.denote_toPoly, Expr.denote]; intro h - replace h := add_lt_right (rhs.denote ctx) h + replace h : rhs.denote ctx + (lhs.denote ctx - rhs.denote ctx) < _ := add_lt_right (rhs.denote ctx) h rw [sub_eq_add_neg, add_left_comm, ← sub_eq_add_neg, sub_self] at h; simp [add_zero] at h contradiction