feat: associativity lemmas for BitVec.(umul, smul, uadd, sadd)Overflow (#8740)

This PR introduces associativity rules and preservation of `(umul, smul,
uadd, sadd)Overflow`flags.

---------

Co-authored-by: Siddharth <siddu.druid@gmail.com>
This commit is contained in:
Luisa Cicolini 2025-06-13 10:07:09 +01:00 committed by GitHub
parent f247f2bdd0
commit 300c22a4e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3354,6 +3354,17 @@ theorem toNat_add_of_not_uaddOverflow {x y : BitVec w} (h : ¬ uaddOverflow x y)
· simp only [uaddOverflow, ge_iff_le, decide_eq_true_eq, Nat.not_le] at h
rw [toNat_add, Nat.mod_eq_of_lt h]
/--
Unsigned addition overflow reassociation.
If `(x + y)` and `(y + z)` do not overflow, then `(x + y) + z` overflows iff `x + (y + z)` overflows.
-/
theorem uaddOverflow_assoc {x y z : BitVec w} (h : ¬ x.uaddOverflow y) (h' : ¬ y.uaddOverflow z) :
(x + y).uaddOverflow z = x.uaddOverflow (y + z) := by
simp only [uaddOverflow, ge_iff_le, decide_eq_true_eq, Nat.not_le] at h h'
simp only [uaddOverflow, toNat_add, ge_iff_le, decide_eq_decide]
repeat rw [Nat.mod_eq_of_lt (by omega)]
omega
protected theorem add_assoc (x y z : BitVec n) : x + y + z = x + (y + z) := by
apply eq_of_toNat_eq ; simp [Nat.add_assoc]
instance : Std.Associative (α := BitVec n) (· + ·) := ⟨BitVec.add_assoc⟩
@ -3392,6 +3403,20 @@ theorem toInt_add_of_not_saddOverflow {x y : BitVec w} (h : ¬ saddOverflow x y)
_root_.not_or, Int.not_le, Int.not_lt] at h
rw [toInt_add, Int.bmod_eq_of_le (by push_cast; omega) (by push_cast; omega)]
/--
Signed addition overflow reassociation.
If `(x + y)` and `(y + z)` do not overflow, then `(x + y) + z` overflows iff `x + (y + z)` overflows.
-/
theorem saddOverflow_assoc {x y z : BitVec w} (h : ¬ x.saddOverflow y) (h' : ¬ y.saddOverflow z) :
(x + y).saddOverflow z = x.saddOverflow (y + z) := by
rcases w with _|w
· simp [of_length_zero]
· simp only [saddOverflow, Nat.add_one_sub_one, ge_iff_le, Bool.or_eq_true, decide_eq_true_eq,
_root_.not_or, Int.not_le, Int.not_lt] at h h'
simp only [bool_to_prop, saddOverflow, toInt_add, ge_iff_le, Nat.add_one_sub_one]
repeat rw [Int.bmod_eq_of_le (by push_cast; omega) (by push_cast; omega)]
omega
@[simp]
theorem shiftLeft_add_distrib {x y : BitVec w} {n : Nat} :
(x + y) <<< n = x <<< n + y <<< n := by
@ -3820,6 +3845,18 @@ theorem toNat_mul_of_not_umulOverflow {x y : BitVec w} (h : ¬ umulOverflow x y)
· simp only [umulOverflow, ge_iff_le, decide_eq_true_eq, Nat.not_le] at h
rw [toNat_mul, Nat.mod_eq_of_lt h]
/--
Unsigned multiplication overflow reassociation.
If `(x * y)` and `(y * z)` do not overflow, then `(x * y) * z` overflows iff `x * (y * z)` overflows.
-/
theorem umulOverflow_assoc {x y z : BitVec w} (h : ¬ x.umulOverflow y) (h' : ¬ y.umulOverflow z) :
(x * y).umulOverflow z = x.umulOverflow (y * z) := by
simp only [umulOverflow, ge_iff_le, decide_eq_true_eq, Nat.not_le] at h h'
simp only [umulOverflow, toNat_mul, ge_iff_le, decide_eq_decide]
repeat rw [Nat.mod_eq_of_lt (by omega)]
rw [Nat.mul_assoc]
@[simp]
theorem toInt_mul_of_not_smulOverflow {x y : BitVec w} (h : ¬ smulOverflow x y) :
(x * y).toInt = x.toInt * y.toInt := by
@ -3829,6 +3866,20 @@ theorem toInt_mul_of_not_smulOverflow {x y : BitVec w} (h : ¬ smulOverflow x y)
_root_.not_or, Int.not_le, Int.not_lt] at h
rw [toInt_mul, Int.bmod_eq_of_le (by push_cast; omega) (by push_cast; omega)]
/--
Signed multiplication overflow reassociation.
If `(x * y)` and `(y * z)` do not overflow, then `(x * y) * z` overflows iff `x * (y * z)` overflows.
-/
theorem smulOverflow_assoc {x y z : BitVec w} (h : ¬ x.smulOverflow y) (h' : ¬ y.smulOverflow z) :
(x * y).smulOverflow z = x.smulOverflow (y * z) := by
rcases w with _|w
· simp [of_length_zero]
· simp only [smulOverflow, Nat.add_one_sub_one, ge_iff_le, Bool.or_eq_true, decide_eq_true_eq,
_root_.not_or, Int.not_le, Int.not_lt] at h h'
simp only [smulOverflow, toInt_mul, Nat.add_one_sub_one, ge_iff_le, bool_to_prop]
repeat rw [Int.bmod_eq_of_le (by push_cast; omega) (by push_cast; omega)]
rw [Int.mul_assoc]
theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) =
BitVec.ofInt n x * BitVec.ofInt n y := by
apply eq_of_toInt_eq