feat: add BitVec.umulOverflow and BitVec.smulOverflow definitions and additional theorems (#7659)

This PR adds SMT-LIB operators to detect overflow
`BitVec.(umul_overflow, smul_overflow)`, according to the definitions
[here](https://github.com/SMT-LIB/SMT-LIB-2/blob/2.7/Theories/FixedSizeBitVectors.smt2),
and the theorems proving equivalence of such definitions with the
`BitVec` library functions (`umulOverflow_eq`, `smulOverflow_eq`).
Support theorems for these proofs are `BitVec.toInt_one_of_lt,
BitVec.toInt_mul_toInt_lt, BitVec.le_toInt_mul_toInt,
BitVec.toNat_mul_toNat_lt, BitVec.two_pow_le_toInt_mul_toInt_iff,
BitVec.toInt_mul_toInt_lt_neg_two_pow_iff` and `Int.neg_mul_le_mul,
Int.bmod_eq_self_of_le_mul_two, Int.mul_le_mul_of_natAbs_le,
Int.mul_le_mul_of_le_of_le_of_nonneg_of_nonpos, Int.pow_lt_pow`. The PR
also includes a set of tests.

Co-authored by @tobiasgrosser.

---------

Co-authored-by: Tobias Grosser <tobias@grosser.es>
Co-authored-by: Tobias Grosser <github@grosser.es>
Co-authored-by: Siddharth <siddu.druid@gmail.com>
This commit is contained in:
Luisa Cicolini 2025-04-03 09:42:52 +01:00 committed by GitHub
parent bb6bfdba37
commit e59d070af1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 199 additions and 1 deletions

View file

@ -760,4 +760,19 @@ def reverse : {w : Nat} → BitVec w → BitVec w
| 0, x => x
| w + 1, x => concat (reverse (x.truncate w)) (x.msb)
/-- `umulOverflow x y` returns `true` if multiplying `x` and `y` results in *unsigned* overflow.
SMT-Lib name: `bvumulo`.
-/
def umulOverflow {w : Nat} (x y : BitVec w) : Bool := x.toNat * y.toNat ≥ 2 ^ w
/-- `smulOverflow x y` returns `true` if multiplying `x` and `y` results in *signed* overflow,
treating `x` and `y` as 2's complement signed bitvectors.
SMT-Lib name: `bvsmulo`.
-/
def smulOverflow {w : Nat} (x y : BitVec w) : Bool :=
(x.toInt * y.toInt ≥ 2 ^ (w - 1)) || (x.toInt * y.toInt < - 2 ^ (w - 1))
end BitVec

View file

@ -7,6 +7,7 @@ prelude
import Init.Data.BitVec.Folds
import Init.Data.Nat.Mod
import Init.Data.Int.LemmasAux
import Init.Data.BitVec.Lemmas
/-!
# Bit blasting of bitvectors
@ -1358,6 +1359,31 @@ theorem negOverflow_eq {w : Nat} (x : BitVec w) :
simp only [toInt_intMin, Nat.add_one_sub_one, Int.ofNat_emod, Int.neg_inj]
rw_mod_cast [Nat.mod_eq_of_lt (by simp [Nat.pow_lt_pow_succ])]
theorem umulOverflow_eq {w : Nat} (x y : BitVec w) :
umulOverflow x y =
(0 < w && BitVec.twoPow (w * 2) w ≤ x.zeroExtend (w * 2) * y.zeroExtend (w * 2)) := by
simp only [umulOverflow, toNat_twoPow, le_def, toNat_mul, toNat_setWidth, mod_mul_mod]
rcases w with _|w
· simp [of_length_zero, toInt_zero, mul_mod_mod]
· simp only [ge_iff_le, show 0 < w + 1 by omega, decide_true, mul_mod_mod, Bool.true_and,
decide_eq_decide]
rw [Nat.mod_eq_of_lt BitVec.toNat_mul_toNat_lt, Nat.mod_eq_of_lt]
have := Nat.pow_lt_pow_of_lt (a := 2) (n := w + 1) (m := (w + 1) * 2)
omega
theorem smulOverflow_eq {w : Nat} (x y : BitVec w) :
smulOverflow x y =
(0 < w &&
((signExtend (w * 2) (intMax w)).slt (signExtend (w * 2) x * signExtend (w * 2) y) ||
(signExtend (w * 2) x * signExtend (w * 2) y).slt (signExtend (w * 2) (intMin w)))) := by
simp only [smulOverflow]
rcases w with _|w
· simp [of_length_zero, toInt_zero]
· have h₁ := BitVec.two_pow_le_toInt_mul_toInt_iff (x := x) (y := y)
have h₂ := BitVec.toInt_mul_toInt_lt_neg_two_pow_iff (x := x) (y := y)
simp only [Nat.add_one_sub_one] at h₁ h₂
simp [h₁, h₂]
/- ### umod -/
theorem getElem_umod {n d : BitVec w} (hi : i < w) :

View file

@ -600,6 +600,14 @@ theorem toInt_nonneg_of_msb_false {x : BitVec w} (h : x.msb = false) : 0 ≤ x.t
have : 2 * x.toNat < 2 ^ w := msb_eq_false_iff_two_mul_lt.mp h
omega
@[simp] theorem toInt_one_of_lt {w : Nat} (h : 1 < w) : (1#w).toInt = 1 := by
rw [toInt_eq_msb_cond]
simp only [msb_one, show w ≠ 1 by omega, decide_false, Bool.false_eq_true, ↓reduceIte,
toNat_ofNat, Int.ofNat_emod]
norm_cast
apply Nat.mod_eq_of_lt
apply Nat.one_lt_two_pow (by omega)
/-- Prove equality of bitvectors in terms of nat operations. -/
theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt → x = y := by
intro eq
@ -4582,6 +4590,30 @@ theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x
have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk
simp [bitvec_to_nat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this]
theorem toInt_mul_toInt_le {x y : BitVec w} : x.toInt * y.toInt ≤ 2 ^ (w * 2 - 2) := by
rcases w with _|w
· simp [of_length_zero]
· have xlt := two_mul_toInt_lt (x := x); have xle := le_two_mul_toInt (x := x)
have ylt := two_mul_toInt_lt (x := y); have yle := le_two_mul_toInt (x := y)
have h : 2 ^ ((w + 1) * 2 - 2) = 2 ^ ((w + 1) - 1) * 2 ^ ((w + 1) - 1) := by
rw [← Nat.pow_add, ←Nat.mul_two, Nat.mul_comm (m := 2) (n := (w + 1) - 1),
Nat.mul_sub_one, Nat.mul_comm]
rw_mod_cast [h]
rw [← Nat.two_pow_pred_mul_two (by omega), Int.natCast_mul] at xlt ylt xle yle
exact Int.mul_le_mul_of_natAbs_le (by omega) (by omega)
theorem le_toInt_mul_toInt {x y : BitVec w} : - (2 ^ (w * 2 - 2)) ≤ x.toInt * y.toInt := by
rcases w with _|w
· simp [of_length_zero]
· have xlt := two_mul_toInt_lt (x := x); have xle := le_two_mul_toInt (x := x)
have ylt := two_mul_toInt_lt (x := y); have yle := le_two_mul_toInt (x := y)
have h : 2 ^ ((w + 1) * 2 - 2) = 2 ^ ((w + 1) - 1) * 2 ^ ((w + 1) - 1) := by
rw [← Nat.pow_add, ←Nat.mul_two, Nat.mul_comm (m := 2) (n := (w + 1) - 1),
Nat.mul_sub_one, Nat.mul_comm]
rw_mod_cast [h]
rw [← Nat.two_pow_pred_mul_two (by omega), Int.natCast_mul] at xlt ylt xle yle
exact Int.neg_mul_le_mul (by omega) (by omega) (by omega) (by omega)
theorem shiftLeft_neg {x : BitVec w} {y : Nat} :
(-x) <<< y = - (x <<< y) := by
rw [shiftLeft_eq_mul_twoPow, shiftLeft_eq_mul_twoPow, BitVec.neg_mul]
@ -4945,6 +4977,10 @@ theorem toNat_mul_of_lt {w} {x y : BitVec w} (h : x.toNat * y.toNat < 2^w) :
(x * y).toNat = x.toNat * y.toNat := by
rw [BitVec.toNat_mul, Nat.mod_eq_of_lt h]
theorem toNat_mul_toNat_lt {x y : BitVec w} : x.toNat * y.toNat < 2 ^ (w * 2) := by
have := BitVec.isLt x; have := BitVec.isLt y
simp only [Nat.mul_two, Nat.pow_add]
exact Nat.mul_lt_mul_of_le_of_lt (by omega) (by omega) (by omega)
/--
`x ≤ y + z` if and only if `x - z ≤ y`
@ -4969,6 +5005,41 @@ theorem sub_le_sub_iff_le {x y z : BitVec w} (hxz : z ≤ x) (hyz : z ≤ y) :
BitVec.toNat_sub_of_le (by rw [BitVec.le_def]; omega)]
omega
theorem two_pow_le_toInt_mul_toInt_iff {x y : BitVec w} :
2 ^ (w - 1) ≤ x.toInt * y.toInt ↔
(signExtend (w * 2) (intMax w)).slt (signExtend (w * 2) x * signExtend (w * 2) y) := by
rcases w with _|w
· simp [of_length_zero]
· have := Int.pow_lt_pow_of_lt (a := 2) (b := (w + 1) * 2 - 2) (c := (w + 1) * 2 - 1) (by omega)
have := @BitVec.le_toInt_mul_toInt (w + 1) x y
have := @BitVec.toInt_mul_toInt_le (w + 1) x y
simp only [Nat.add_one_sub_one, BitVec.slt, intMax, ofNat_eq_ofNat, toInt_mul,
decide_eq_true_eq]
repeat rw [BitVec.toInt_signExtend_of_le (by omega)]
simp only [show BitVec.twoPow (w + 1) w - 1#(w + 1) = BitVec.intMax (w + 1) by simp [intMax],
toInt_intMax, Nat.add_one_sub_one]
push_cast
rw [← Nat.two_pow_pred_add_two_pow_pred (by omega),
Int.bmod_eq_self_of_le_mul_two (by rw [← Nat.mul_two]; push_cast; omega)
(by rw [← Nat.mul_two]; push_cast; omega)]
omega
theorem toInt_mul_toInt_lt_neg_two_pow_iff {x y : BitVec w} :
x.toInt * y.toInt < - 2 ^ (w - 1) ↔
(signExtend (w * 2) x * signExtend (w * 2) y).slt (signExtend (w * 2) (intMin w)) := by
rcases w with _|w
· simp [of_length_zero]
· have := Int.pow_lt_pow_of_lt (a := 2) (b := (w + 1) * 2 - 2) (c := (w + 1) * 2 - 1) (by omega)
have := @BitVec.le_toInt_mul_toInt (w + 1) x y
have := @BitVec.toInt_mul_toInt_le (w + 1) x y
simp only [BitVec.slt, toInt_mul, intMin, Nat.add_one_sub_one, decide_eq_true_eq]
repeat rw [BitVec.toInt_signExtend_of_le (by omega)]
simp only [toInt_twoPow, show ¬w + 1 ≤ w by omega, ↓reduceIte]
push_cast
rw [← Nat.two_pow_pred_add_two_pow_pred (by omega),
Int.bmod_eq_self_of_le_mul_two (by rw [← Nat.mul_two]; push_cast; omega)
(by rw [← Nat.mul_two]; push_cast; omega)]
/-! ### neg -/
theorem msb_eq_toInt {x : BitVec w}:

View file

@ -151,4 +151,59 @@ theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n m) :
obtain ⟨k, rfl⟩ := hnm
simp [Int.mul_assoc]
theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y ≤ x * 2) (hlt : x * 2 < y) :
x.bmod y = x := by
apply bmod_eq_self_of_le (by omega) (by omega)
theorem mul_le_mul_of_natAbs_le {x y : Int} {s t : Nat} (hx : x.natAbs ≤ s) (hy : y.natAbs ≤ t) :
x * y ≤ s * t := by
by_cases 0 < s ∧ 0 < t
· have := Nat.mul_pos (n := s) (m := t) (by omega) (by omega)
by_cases hx : 0 < x <;> by_cases hy : 0 < y
· apply Int.mul_le_mul <;> omega
· have : x * y ≤ 0 := Int.mul_nonpos_of_nonneg_of_nonpos (by omega) (by omega); omega
· have : x * y ≤ 0 := Int.mul_nonpos_of_nonpos_of_nonneg (by omega) (by omega); omega
· have : -x * -y ≤ s * t := Int.mul_le_mul (by omega) (by omega) (by omega) (by omega)
simp [Int.neg_mul_neg] at this
norm_cast
· have : (x = 0 y = 0) → x * y = 0 := by simp [Int.mul_eq_zero]
norm_cast
omega
/--
This is a generalization of `a ≤ c` and `b ≤ d` implying `a * b ≤ c * d` for natural numbers,
appropriately generalized to integers when `b` is nonnegative and `c` is nonpositive.
-/
theorem mul_le_mul_of_le_of_le_of_nonneg_of_nonpos {a b c d : Int}
(hac : a ≤ c) (hbd : d ≤ b) (hb : 0 ≤ b) (hc : c ≤ 0) : a * b ≤ c * d :=
Int.le_trans (Int.mul_le_mul_of_nonneg_right hac hb) (Int.mul_le_mul_of_nonpos_left hc hbd)
theorem mul_le_mul_of_le_of_le_of_nonneg_of_nonneg {a b c d : Int}
(hac : a ≤ c) (hbd : b ≤ d) (hb : 0 ≤ b) (hc : 0 ≤ c) : a * b ≤ c * d :=
Int.le_trans (Int.mul_le_mul_of_nonneg_right hac hb) (Int.mul_le_mul_of_nonneg_left hbd hc)
theorem mul_le_mul_of_le_of_le_of_nonpos_of_nonpos {a b c d : Int}
(hac : c ≤ a) (hbd : d ≤ b) (hb : b ≤ 0) (hc : c ≤ 0) : a * b ≤ c * d :=
Int.le_trans (Int.mul_le_mul_of_nonpos_right hac hb) (Int.mul_le_mul_of_nonpos_left hc hbd)
theorem mul_le_mul_of_le_of_le_of_nonpos_of_nonneg {a b c d : Int}
(hac : c ≤ a) (hbd : b ≤ d) (hb : b ≤ 0) (hc : 0 ≤ c) : a * b ≤ c * d :=
Int.le_trans (Int.mul_le_mul_of_nonpos_right hac hb) (Int.mul_le_mul_of_nonneg_left hbd hc)
/--
A corollary of |s| ≤ x, and |t| ≤ y, then |s * t| ≤ x * y,
-/
theorem neg_mul_le_mul {x y : Int} {s t : Nat} (lbx : -s ≤ x) (ubx : x < s) (lby : -t ≤ y) (uby : y < t) :
-(s * t) ≤ x * y := by
have := Nat.mul_pos (n := s) (m := t) (by omega) (by omega)
by_cases 0 ≤ x <;> by_cases 0 ≤ y
· have : 0 ≤ x * y := by apply Int.mul_nonneg <;> omega
norm_cast
omega
· rw [Int.mul_comm (a := x), Int.mul_comm (a := (s : Int)), ← Int.neg_mul]; apply Int.mul_le_mul_of_le_of_le_of_nonneg_of_nonpos <;> omega
· rw [← Int.neg_mul]; apply Int.mul_le_mul_of_le_of_le_of_nonneg_of_nonpos <;> omega
· have : 0 < x * y := by apply Int.mul_pos_of_neg_of_neg <;> omega
norm_cast
omega
end Int

View file

@ -56,4 +56,11 @@ protected theorem two_pow_pred_sub_two_pow' {w : Nat} (h : 0 < w) :
rw [← Nat.two_pow_pred_add_two_pow_pred h]
simp [h]
theorem pow_lt_pow_of_lt {a : Int} {b c : Nat} (ha : 1 < a) (hbc : b < c):
a ^ b < a ^ c := by
rw [← Int.toNat_of_nonneg (a := a) (by omega), ← Int.natCast_pow, ← Int.natCast_pow]
have := Nat.pow_lt_pow_of_lt (a := a.toNat) (m := c) (n := b)
simp only [Int.ofNat_lt]
omega
end Int

View file

@ -166,6 +166,9 @@ end Constant
attribute [bv_normalize] BitVec.zero_and
attribute [bv_normalize] BitVec.and_zero
attribute [bv_normalize] BitVec.intMax
attribute [bv_normalize] BitVec.intMin
-- Used in simproc because of - normalization
theorem BitVec.ones_and (a : BitVec w) : (-1#w) &&& a = a := by
ext
@ -355,6 +358,8 @@ attribute [bv_normalize] BitVec.umod_eq_and
attribute [bv_normalize] BitVec.saddOverflow_eq
attribute [bv_normalize] BitVec.uaddOverflow_eq
attribute [bv_normalize] BitVec.negOverflow_eq
attribute [bv_normalize] BitVec.umulOverflow_eq
attribute [bv_normalize] BitVec.smulOverflow_eq
attribute [bv_normalize] BitVec.usubOverflow_eq
attribute [bv_normalize] BitVec.ssubOverflow_eq

View file

@ -85,7 +85,17 @@ example (x y : BitVec 16) : BitVec.uaddOverflow x y = (x.setWidth (17) + y.setWi
example (x y : BitVec 16) : BitVec.saddOverflow x y = (x.msb = y.msb ∧ ¬(x + y).msb = x.msb) := by bv_normalize
example (x y : BitVec w) : BitVec.uaddOverflow x y = (x.setWidth (w + 1) + y.setWidth (w + 1)).msb := by bv_normalize
example (x y : BitVec w) : BitVec.saddOverflow x y = (x.msb = y.msb ∧ ¬(x + y).msb = x.msb) := by bv_normalize
example (x y : BitVec 16) : BitVec.umulOverflow x y = (BitVec.twoPow 32 16 ≤ x.zeroExtend (32) * y.zeroExtend (32)) := by bv_normalize
example (x y : BitVec 16) : BitVec.smulOverflow x y =
((BitVec.signExtend (16 * 2) (BitVec.intMax 16)).slt (BitVec.signExtend (16 * 2) x * BitVec.signExtend (16 * 2) y) ||
(BitVec.signExtend (16 * 2) x * BitVec.signExtend (16 * 2) y).slt (BitVec.signExtend (16 * 2) (BitVec.intMin 16))) :=
by bv_normalize
example (x y : BitVec w) : BitVec.umulOverflow x y = (0 < w && BitVec.twoPow (w * 2) w ≤ x.zeroExtend (w * 2) * y.zeroExtend (w * 2)) := by bv_normalize
example (x y : BitVec w) : BitVec.smulOverflow x y =
(decide (0 < w) &&
((BitVec.signExtend (w * 2) (BitVec.intMax w)).slt (BitVec.signExtend (w * 2) x * BitVec.signExtend (w * 2) y) ||
(BitVec.signExtend (w * 2) x * BitVec.signExtend (w * 2) y).slt (BitVec.signExtend (w * 2) (BitVec.intMin w))))
:= by bv_normalize
-- not_neg
example {x : BitVec 16} : ~~~(-x) = x + (-1#16) := by bv_normalize
@ -623,6 +633,15 @@ example {x : BitVec 16} : (x = BitVec.allOnes 16) → (BitVec.uaddOverflow x x)
example {x : BitVec 64} : (x = BitVec.intMin 64) ↔ (BitVec.negOverflow x) := by bv_decide
example {x : BitVec 16} : (x = BitVec.allOnes 16) → (BitVec.umulOverflow x x) := by bv_decide
example {x : BitVec 8} : (x = -32#8) → (BitVec.smulOverflow x x) := by bv_decide
example {x : BitVec 8} : (x = 0#8) → (¬ BitVec.smulOverflow x x) := by bv_decide
example {x : BitVec 8} : (x ≥ -2#8) → (¬ BitVec.smulOverflow x x) := by bv_decide
example {x : BitVec 8} : (x < 12#8) → (¬ BitVec.smulOverflow x x) := by bv_decide
example {x y : BitVec 64} : ((x = 0#64) ∧ (y = BitVec.allOnes 64)) → (BitVec.usubOverflow x y) := by bv_decide