feat: commute BitVec.extractLsb(')? with bitwise ops (#6747)

This PR adds the ability to push `BitVec.extractLsb` and
`BitVec.extractLsb'` with bitwise operations. This is useful for
constant-folding extracts.
This commit is contained in:
Siddharth 2025-01-24 15:23:30 +00:00 committed by GitHub
parent 1059e25ca2
commit 044bf85fe9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -905,6 +905,16 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· ||| · ) (0#n) where
ext i h
simp [h]
theorem extractLsb'_or {x y : BitVec w} {start len : Nat} :
(x ||| y).extractLsb' start len = (x.extractLsb' start len) ||| (y.extractLsb' start len) := by
ext i hi
simp [hi]
theorem extractLsb_or {x : BitVec w} {hi lo : Nat} :
(x ||| y).extractLsb lo hi = (x.extractLsb lo hi) ||| (y.extractLsb lo hi) := by
ext k hk
simp [hk, show k ≤ lo - hi by omega]
/-! ### and -/
@[simp] theorem toNat_and (x y : BitVec v) :
@ -978,6 +988,16 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· &&& · ) (allOnes n) wher
ext i h
simp [h]
theorem extractLsb'_and {x y : BitVec w} {start len : Nat} :
(x &&& y).extractLsb' start len = (x.extractLsb' start len) &&& (y.extractLsb' start len) := by
ext i hi
simp [hi]
theorem extractLsb_and {x : BitVec w} {hi lo : Nat} :
(x &&& y).extractLsb lo hi = (x.extractLsb lo hi) &&& (y.extractLsb lo hi) := by
ext k hk
simp [hk, show k ≤ lo - hi by omega]
/-! ### xor -/
@[simp] theorem toNat_xor (x y : BitVec v) :
@ -1043,6 +1063,16 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· ^^^ · ) (0#n) where
ext i
simp
theorem extractLsb'_xor {x y : BitVec w} {start len : Nat} :
(x ^^^ y).extractLsb' start len = (x.extractLsb' start len) ^^^ (y.extractLsb' start len) := by
ext i hi
simp [hi]
theorem extractLsb_xor {x : BitVec w} {hi lo : Nat} :
(x ^^^ y).extractLsb lo hi = (x.extractLsb lo hi) ^^^ (y.extractLsb lo hi) := by
ext k hk
simp [hk, show k ≤ lo - hi by omega]
/-! ### not -/
theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
@ -1149,6 +1179,31 @@ theorem getMsb_not {x : BitVec w} :
@[simp] theorem msb_not {x : BitVec w} : (~~~x).msb = (decide (0 < w) && !x.msb) := by
simp [BitVec.msb]
/--
Negating `x` and then extracting [start..start+len) is the same as extracting and then negating,
as long as the range [start..start+len) is in bounds.
See that if the index is out-of-bounds, then `extractLsb` will return `false`,
which makes the operation not commute.
-/
theorem extractLsb'_not_of_lt {x : BitVec w} {start len : Nat} (h : start + len < w) :
(~~~ x).extractLsb' start len = ~~~ (x.extractLsb' start len) := by
ext i hi
simp [hi]
omega
/--
Negating `x` and then extracting [lo:hi] is the same as extracting and then negating.
For the extraction to be well-behaved,
we need the range [lo:hi] to be a valid closed interval inside the bitvector:
1. `lo ≤ hi` for the interval to be a well-formed closed interval.
2. `hi < w`, for the interval to be contained inside the bitvector.
-/
theorem extractLsb_not_of_lt {x : BitVec w} {hi lo : Nat} (hlo : lo ≤ hi) (hhi : hi < w) :
(~~~ x).extractLsb hi lo = ~~~ (x.extractLsb hi lo) := by
ext k hk
simp [hk, show k ≤ hi - lo by omega]
omega
/-! ### cast -/
@[simp] theorem not_cast {x : BitVec w} (h : w = w') : ~~~(x.cast h) = (~~~x).cast h := by