From e59d070af14a67b3e36ea657dc617b02fb32127b Mon Sep 17 00:00:00 2001 From: Luisa Cicolini <48860705+luisacicolini@users.noreply.github.com> Date: Thu, 3 Apr 2025 09:42:52 +0100 Subject: [PATCH] feat: add `BitVec.umulOverflow` and `BitVec.smulOverflow` definitions and additional theorems (#7659) This PR adds SMT-LIB operators to detect overflow `BitVec.(umul_overflow, smul_overflow)`, according to the definitions [here](https://github.com/SMT-LIB/SMT-LIB-2/blob/2.7/Theories/FixedSizeBitVectors.smt2), and the theorems proving equivalence of such definitions with the `BitVec` library functions (`umulOverflow_eq`, `smulOverflow_eq`). Support theorems for these proofs are `BitVec.toInt_one_of_lt, BitVec.toInt_mul_toInt_lt, BitVec.le_toInt_mul_toInt, BitVec.toNat_mul_toNat_lt, BitVec.two_pow_le_toInt_mul_toInt_iff, BitVec.toInt_mul_toInt_lt_neg_two_pow_iff` and `Int.neg_mul_le_mul, Int.bmod_eq_self_of_le_mul_two, Int.mul_le_mul_of_natAbs_le, Int.mul_le_mul_of_le_of_le_of_nonneg_of_nonpos, Int.pow_lt_pow`. The PR also includes a set of tests. Co-authored by @tobiasgrosser. --------- Co-authored-by: Tobias Grosser Co-authored-by: Tobias Grosser Co-authored-by: Siddharth --- src/Init/Data/BitVec/Basic.lean | 15 ++++ src/Init/Data/BitVec/Bitblast.lean | 26 +++++++ src/Init/Data/BitVec/Lemmas.lean | 71 +++++++++++++++++++ src/Init/Data/Int/LemmasAux.lean | 55 ++++++++++++++ src/Init/Data/Int/Pow.lean | 7 ++ src/Std/Tactic/BVDecide/Normalize/BitVec.lean | 5 ++ tests/lean/run/bv_decide_rewriter.lean | 21 +++++- 7 files changed, 199 insertions(+), 1 deletion(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index dd1db9ba31..adfb22aea0 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -760,4 +760,19 @@ def reverse : {w : Nat} → BitVec w → BitVec w | 0, x => x | w + 1, x => concat (reverse (x.truncate w)) (x.msb) +/-- `umulOverflow x y` returns `true` if multiplying `x` and `y` results in *unsigned* overflow. + + SMT-Lib name: `bvumulo`. +-/ +def umulOverflow {w : Nat} (x y : BitVec w) : Bool := x.toNat * y.toNat ≥ 2 ^ w + +/-- `smulOverflow x y` returns `true` if multiplying `x` and `y` results in *signed* overflow, +treating `x` and `y` as 2's complement signed bitvectors. + + SMT-Lib name: `bvsmulo`. +-/ + +def smulOverflow {w : Nat} (x y : BitVec w) : Bool := + (x.toInt * y.toInt ≥ 2 ^ (w - 1)) || (x.toInt * y.toInt < - 2 ^ (w - 1)) + end BitVec diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 107b0d6b54..7ae15cf92f 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -7,6 +7,7 @@ prelude import Init.Data.BitVec.Folds import Init.Data.Nat.Mod import Init.Data.Int.LemmasAux +import Init.Data.BitVec.Lemmas /-! # Bit blasting of bitvectors @@ -1358,6 +1359,31 @@ theorem negOverflow_eq {w : Nat} (x : BitVec w) : simp only [toInt_intMin, Nat.add_one_sub_one, Int.ofNat_emod, Int.neg_inj] rw_mod_cast [Nat.mod_eq_of_lt (by simp [Nat.pow_lt_pow_succ])] +theorem umulOverflow_eq {w : Nat} (x y : BitVec w) : + umulOverflow x y = + (0 < w && BitVec.twoPow (w * 2) w ≤ x.zeroExtend (w * 2) * y.zeroExtend (w * 2)) := by + simp only [umulOverflow, toNat_twoPow, le_def, toNat_mul, toNat_setWidth, mod_mul_mod] + rcases w with _|w + · simp [of_length_zero, toInt_zero, mul_mod_mod] + · simp only [ge_iff_le, show 0 < w + 1 by omega, decide_true, mul_mod_mod, Bool.true_and, + decide_eq_decide] + rw [Nat.mod_eq_of_lt BitVec.toNat_mul_toNat_lt, Nat.mod_eq_of_lt] + have := Nat.pow_lt_pow_of_lt (a := 2) (n := w + 1) (m := (w + 1) * 2) + omega + +theorem smulOverflow_eq {w : Nat} (x y : BitVec w) : + smulOverflow x y = + (0 < w && + ((signExtend (w * 2) (intMax w)).slt (signExtend (w * 2) x * signExtend (w * 2) y) || + (signExtend (w * 2) x * signExtend (w * 2) y).slt (signExtend (w * 2) (intMin w)))) := by + simp only [smulOverflow] + rcases w with _|w + · simp [of_length_zero, toInt_zero] + · have h₁ := BitVec.two_pow_le_toInt_mul_toInt_iff (x := x) (y := y) + have h₂ := BitVec.toInt_mul_toInt_lt_neg_two_pow_iff (x := x) (y := y) + simp only [Nat.add_one_sub_one] at h₁ h₂ + simp [h₁, h₂] + /- ### umod -/ theorem getElem_umod {n d : BitVec w} (hi : i < w) : diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index cf2b7a24d1..5e99971830 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -600,6 +600,14 @@ theorem toInt_nonneg_of_msb_false {x : BitVec w} (h : x.msb = false) : 0 ≤ x.t have : 2 * x.toNat < 2 ^ w := msb_eq_false_iff_two_mul_lt.mp h omega +@[simp] theorem toInt_one_of_lt {w : Nat} (h : 1 < w) : (1#w).toInt = 1 := by + rw [toInt_eq_msb_cond] + simp only [msb_one, show w ≠ 1 by omega, decide_false, Bool.false_eq_true, ↓reduceIte, + toNat_ofNat, Int.ofNat_emod] + norm_cast + apply Nat.mod_eq_of_lt + apply Nat.one_lt_two_pow (by omega) + /-- Prove equality of bitvectors in terms of nat operations. -/ theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt → x = y := by intro eq @@ -4582,6 +4590,30 @@ theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk simp [bitvec_to_nat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this] +theorem toInt_mul_toInt_le {x y : BitVec w} : x.toInt * y.toInt ≤ 2 ^ (w * 2 - 2) := by + rcases w with _|w + · simp [of_length_zero] + · have xlt := two_mul_toInt_lt (x := x); have xle := le_two_mul_toInt (x := x) + have ylt := two_mul_toInt_lt (x := y); have yle := le_two_mul_toInt (x := y) + have h : 2 ^ ((w + 1) * 2 - 2) = 2 ^ ((w + 1) - 1) * 2 ^ ((w + 1) - 1) := by + rw [← Nat.pow_add, ←Nat.mul_two, Nat.mul_comm (m := 2) (n := (w + 1) - 1), + Nat.mul_sub_one, Nat.mul_comm] + rw_mod_cast [h] + rw [← Nat.two_pow_pred_mul_two (by omega), Int.natCast_mul] at xlt ylt xle yle + exact Int.mul_le_mul_of_natAbs_le (by omega) (by omega) + +theorem le_toInt_mul_toInt {x y : BitVec w} : - (2 ^ (w * 2 - 2)) ≤ x.toInt * y.toInt := by + rcases w with _|w + · simp [of_length_zero] + · have xlt := two_mul_toInt_lt (x := x); have xle := le_two_mul_toInt (x := x) + have ylt := two_mul_toInt_lt (x := y); have yle := le_two_mul_toInt (x := y) + have h : 2 ^ ((w + 1) * 2 - 2) = 2 ^ ((w + 1) - 1) * 2 ^ ((w + 1) - 1) := by + rw [← Nat.pow_add, ←Nat.mul_two, Nat.mul_comm (m := 2) (n := (w + 1) - 1), + Nat.mul_sub_one, Nat.mul_comm] + rw_mod_cast [h] + rw [← Nat.two_pow_pred_mul_two (by omega), Int.natCast_mul] at xlt ylt xle yle + exact Int.neg_mul_le_mul (by omega) (by omega) (by omega) (by omega) + theorem shiftLeft_neg {x : BitVec w} {y : Nat} : (-x) <<< y = - (x <<< y) := by rw [shiftLeft_eq_mul_twoPow, shiftLeft_eq_mul_twoPow, BitVec.neg_mul] @@ -4945,6 +4977,10 @@ theorem toNat_mul_of_lt {w} {x y : BitVec w} (h : x.toNat * y.toNat < 2^w) : (x * y).toNat = x.toNat * y.toNat := by rw [BitVec.toNat_mul, Nat.mod_eq_of_lt h] +theorem toNat_mul_toNat_lt {x y : BitVec w} : x.toNat * y.toNat < 2 ^ (w * 2) := by + have := BitVec.isLt x; have := BitVec.isLt y + simp only [Nat.mul_two, Nat.pow_add] + exact Nat.mul_lt_mul_of_le_of_lt (by omega) (by omega) (by omega) /-- `x ≤ y + z` if and only if `x - z ≤ y` @@ -4969,6 +5005,41 @@ theorem sub_le_sub_iff_le {x y z : BitVec w} (hxz : z ≤ x) (hyz : z ≤ y) : BitVec.toNat_sub_of_le (by rw [BitVec.le_def]; omega)] omega +theorem two_pow_le_toInt_mul_toInt_iff {x y : BitVec w} : + 2 ^ (w - 1) ≤ x.toInt * y.toInt ↔ + (signExtend (w * 2) (intMax w)).slt (signExtend (w * 2) x * signExtend (w * 2) y) := by + rcases w with _|w + · simp [of_length_zero] + · have := Int.pow_lt_pow_of_lt (a := 2) (b := (w + 1) * 2 - 2) (c := (w + 1) * 2 - 1) (by omega) + have := @BitVec.le_toInt_mul_toInt (w + 1) x y + have := @BitVec.toInt_mul_toInt_le (w + 1) x y + simp only [Nat.add_one_sub_one, BitVec.slt, intMax, ofNat_eq_ofNat, toInt_mul, + decide_eq_true_eq] + repeat rw [BitVec.toInt_signExtend_of_le (by omega)] + simp only [show BitVec.twoPow (w + 1) w - 1#(w + 1) = BitVec.intMax (w + 1) by simp [intMax], + toInt_intMax, Nat.add_one_sub_one] + push_cast + rw [← Nat.two_pow_pred_add_two_pow_pred (by omega), + Int.bmod_eq_self_of_le_mul_two (by rw [← Nat.mul_two]; push_cast; omega) + (by rw [← Nat.mul_two]; push_cast; omega)] + omega + +theorem toInt_mul_toInt_lt_neg_two_pow_iff {x y : BitVec w} : + x.toInt * y.toInt < - 2 ^ (w - 1) ↔ + (signExtend (w * 2) x * signExtend (w * 2) y).slt (signExtend (w * 2) (intMin w)) := by + rcases w with _|w + · simp [of_length_zero] + · have := Int.pow_lt_pow_of_lt (a := 2) (b := (w + 1) * 2 - 2) (c := (w + 1) * 2 - 1) (by omega) + have := @BitVec.le_toInt_mul_toInt (w + 1) x y + have := @BitVec.toInt_mul_toInt_le (w + 1) x y + simp only [BitVec.slt, toInt_mul, intMin, Nat.add_one_sub_one, decide_eq_true_eq] + repeat rw [BitVec.toInt_signExtend_of_le (by omega)] + simp only [toInt_twoPow, show ¬w + 1 ≤ w by omega, ↓reduceIte] + push_cast + rw [← Nat.two_pow_pred_add_two_pow_pred (by omega), + Int.bmod_eq_self_of_le_mul_two (by rw [← Nat.mul_two]; push_cast; omega) + (by rw [← Nat.mul_two]; push_cast; omega)] + /-! ### neg -/ theorem msb_eq_toInt {x : BitVec w}: diff --git a/src/Init/Data/Int/LemmasAux.lean b/src/Init/Data/Int/LemmasAux.lean index 9217b192ed..b69a53936f 100644 --- a/src/Init/Data/Int/LemmasAux.lean +++ b/src/Init/Data/Int/LemmasAux.lean @@ -151,4 +151,59 @@ theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n ∣ m) : obtain ⟨k, rfl⟩ := hnm simp [Int.mul_assoc] +theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y ≤ x * 2) (hlt : x * 2 < y) : + x.bmod y = x := by + apply bmod_eq_self_of_le (by omega) (by omega) + +theorem mul_le_mul_of_natAbs_le {x y : Int} {s t : Nat} (hx : x.natAbs ≤ s) (hy : y.natAbs ≤ t) : + x * y ≤ s * t := by + by_cases 0 < s ∧ 0 < t + · have := Nat.mul_pos (n := s) (m := t) (by omega) (by omega) + by_cases hx : 0 < x <;> by_cases hy : 0 < y + · apply Int.mul_le_mul <;> omega + · have : x * y ≤ 0 := Int.mul_nonpos_of_nonneg_of_nonpos (by omega) (by omega); omega + · have : x * y ≤ 0 := Int.mul_nonpos_of_nonpos_of_nonneg (by omega) (by omega); omega + · have : -x * -y ≤ s * t := Int.mul_le_mul (by omega) (by omega) (by omega) (by omega) + simp [Int.neg_mul_neg] at this + norm_cast + · have : (x = 0 ∨ y = 0) → x * y = 0 := by simp [Int.mul_eq_zero] + norm_cast + omega + +/-- +This is a generalization of `a ≤ c` and `b ≤ d` implying `a * b ≤ c * d` for natural numbers, +appropriately generalized to integers when `b` is nonnegative and `c` is nonpositive. +-/ +theorem mul_le_mul_of_le_of_le_of_nonneg_of_nonpos {a b c d : Int} + (hac : a ≤ c) (hbd : d ≤ b) (hb : 0 ≤ b) (hc : c ≤ 0) : a * b ≤ c * d := + Int.le_trans (Int.mul_le_mul_of_nonneg_right hac hb) (Int.mul_le_mul_of_nonpos_left hc hbd) + +theorem mul_le_mul_of_le_of_le_of_nonneg_of_nonneg {a b c d : Int} + (hac : a ≤ c) (hbd : b ≤ d) (hb : 0 ≤ b) (hc : 0 ≤ c) : a * b ≤ c * d := + Int.le_trans (Int.mul_le_mul_of_nonneg_right hac hb) (Int.mul_le_mul_of_nonneg_left hbd hc) + +theorem mul_le_mul_of_le_of_le_of_nonpos_of_nonpos {a b c d : Int} + (hac : c ≤ a) (hbd : d ≤ b) (hb : b ≤ 0) (hc : c ≤ 0) : a * b ≤ c * d := + Int.le_trans (Int.mul_le_mul_of_nonpos_right hac hb) (Int.mul_le_mul_of_nonpos_left hc hbd) + +theorem mul_le_mul_of_le_of_le_of_nonpos_of_nonneg {a b c d : Int} + (hac : c ≤ a) (hbd : b ≤ d) (hb : b ≤ 0) (hc : 0 ≤ c) : a * b ≤ c * d := + Int.le_trans (Int.mul_le_mul_of_nonpos_right hac hb) (Int.mul_le_mul_of_nonneg_left hbd hc) + +/-- +A corollary of |s| ≤ x, and |t| ≤ y, then |s * t| ≤ x * y, +-/ +theorem neg_mul_le_mul {x y : Int} {s t : Nat} (lbx : -s ≤ x) (ubx : x < s) (lby : -t ≤ y) (uby : y < t) : + -(s * t) ≤ x * y := by + have := Nat.mul_pos (n := s) (m := t) (by omega) (by omega) + by_cases 0 ≤ x <;> by_cases 0 ≤ y + · have : 0 ≤ x * y := by apply Int.mul_nonneg <;> omega + norm_cast + omega + · rw [Int.mul_comm (a := x), Int.mul_comm (a := (s : Int)), ← Int.neg_mul]; apply Int.mul_le_mul_of_le_of_le_of_nonneg_of_nonpos <;> omega + · rw [← Int.neg_mul]; apply Int.mul_le_mul_of_le_of_le_of_nonneg_of_nonpos <;> omega + · have : 0 < x * y := by apply Int.mul_pos_of_neg_of_neg <;> omega + norm_cast + omega + end Int diff --git a/src/Init/Data/Int/Pow.lean b/src/Init/Data/Int/Pow.lean index fb34dd01ca..16c63d1945 100644 --- a/src/Init/Data/Int/Pow.lean +++ b/src/Init/Data/Int/Pow.lean @@ -56,4 +56,11 @@ protected theorem two_pow_pred_sub_two_pow' {w : Nat} (h : 0 < w) : rw [← Nat.two_pow_pred_add_two_pow_pred h] simp [h] +theorem pow_lt_pow_of_lt {a : Int} {b c : Nat} (ha : 1 < a) (hbc : b < c): + a ^ b < a ^ c := by + rw [← Int.toNat_of_nonneg (a := a) (by omega), ← Int.natCast_pow, ← Int.natCast_pow] + have := Nat.pow_lt_pow_of_lt (a := a.toNat) (m := c) (n := b) + simp only [Int.ofNat_lt] + omega + end Int diff --git a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean index 6fcc99ab6c..fef0384ea8 100644 --- a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean +++ b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean @@ -166,6 +166,9 @@ end Constant attribute [bv_normalize] BitVec.zero_and attribute [bv_normalize] BitVec.and_zero +attribute [bv_normalize] BitVec.intMax +attribute [bv_normalize] BitVec.intMin + -- Used in simproc because of - normalization theorem BitVec.ones_and (a : BitVec w) : (-1#w) &&& a = a := by ext @@ -355,6 +358,8 @@ attribute [bv_normalize] BitVec.umod_eq_and attribute [bv_normalize] BitVec.saddOverflow_eq attribute [bv_normalize] BitVec.uaddOverflow_eq attribute [bv_normalize] BitVec.negOverflow_eq +attribute [bv_normalize] BitVec.umulOverflow_eq +attribute [bv_normalize] BitVec.smulOverflow_eq attribute [bv_normalize] BitVec.usubOverflow_eq attribute [bv_normalize] BitVec.ssubOverflow_eq diff --git a/tests/lean/run/bv_decide_rewriter.lean b/tests/lean/run/bv_decide_rewriter.lean index 95602f8fef..851a966c54 100644 --- a/tests/lean/run/bv_decide_rewriter.lean +++ b/tests/lean/run/bv_decide_rewriter.lean @@ -85,7 +85,17 @@ example (x y : BitVec 16) : BitVec.uaddOverflow x y = (x.setWidth (17) + y.setWi example (x y : BitVec 16) : BitVec.saddOverflow x y = (x.msb = y.msb ∧ ¬(x + y).msb = x.msb) := by bv_normalize example (x y : BitVec w) : BitVec.uaddOverflow x y = (x.setWidth (w + 1) + y.setWidth (w + 1)).msb := by bv_normalize example (x y : BitVec w) : BitVec.saddOverflow x y = (x.msb = y.msb ∧ ¬(x + y).msb = x.msb) := by bv_normalize - +example (x y : BitVec 16) : BitVec.umulOverflow x y = (BitVec.twoPow 32 16 ≤ x.zeroExtend (32) * y.zeroExtend (32)) := by bv_normalize +example (x y : BitVec 16) : BitVec.smulOverflow x y = + ((BitVec.signExtend (16 * 2) (BitVec.intMax 16)).slt (BitVec.signExtend (16 * 2) x * BitVec.signExtend (16 * 2) y) || + (BitVec.signExtend (16 * 2) x * BitVec.signExtend (16 * 2) y).slt (BitVec.signExtend (16 * 2) (BitVec.intMin 16))) := + by bv_normalize +example (x y : BitVec w) : BitVec.umulOverflow x y = (0 < w && BitVec.twoPow (w * 2) w ≤ x.zeroExtend (w * 2) * y.zeroExtend (w * 2)) := by bv_normalize +example (x y : BitVec w) : BitVec.smulOverflow x y = + (decide (0 < w) && + ((BitVec.signExtend (w * 2) (BitVec.intMax w)).slt (BitVec.signExtend (w * 2) x * BitVec.signExtend (w * 2) y) || + (BitVec.signExtend (w * 2) x * BitVec.signExtend (w * 2) y).slt (BitVec.signExtend (w * 2) (BitVec.intMin w)))) + := by bv_normalize -- not_neg example {x : BitVec 16} : ~~~(-x) = x + (-1#16) := by bv_normalize @@ -623,6 +633,15 @@ example {x : BitVec 16} : (x = BitVec.allOnes 16) → (BitVec.uaddOverflow x x) example {x : BitVec 64} : (x = BitVec.intMin 64) ↔ (BitVec.negOverflow x) := by bv_decide +example {x : BitVec 16} : (x = BitVec.allOnes 16) → (BitVec.umulOverflow x x) := by bv_decide + +example {x : BitVec 8} : (x = -32#8) → (BitVec.smulOverflow x x) := by bv_decide + +example {x : BitVec 8} : (x = 0#8) → (¬ BitVec.smulOverflow x x) := by bv_decide + +example {x : BitVec 8} : (x ≥ -2#8) → (¬ BitVec.smulOverflow x x) := by bv_decide + +example {x : BitVec 8} : (x < 12#8) → (¬ BitVec.smulOverflow x x) := by bv_decide example {x y : BitVec 64} : ((x = 0#64) ∧ (y = BitVec.allOnes 64)) → (BitVec.usubOverflow x y) := by bv_decide