feat: supporting Nat and BitVec material for finite types (#7598)

This PR adds miscellaneous results about `Nat` and `BitVec` that will be
required for `IntX` theory (#7592).
This commit is contained in:
Markus Himmel 2025-03-24 16:04:53 +01:00 committed by GitHub
parent 3c2d81d3c0
commit 92439acee5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 100 additions and 11 deletions

View file

@ -1400,7 +1400,7 @@ theorem not_sub_eq_not_add {x y : BitVec w} : ~~~ (x - y) = ~~~ x + y := by
/-- The value of `(carry i x y false)` can be computed by truncating `x` and `y`
to `len` bits where `len ≥ i`. -/
theorem carry_extractLsb'_eq_carry {w i len : Nat} (hi : i < len)
{x y : BitVec w} {b : Bool}:
{x y : BitVec w} {b : Bool}:
(carry i (extractLsb' 0 len x) (extractLsb' 0 len y) b)
= (carry i x y b) := by
simp only [carry, extractLsb'_toNat, shiftRight_zero, toNat_false, Nat.add_zero, ge_iff_le,
@ -1414,23 +1414,15 @@ theorem carry_extractLsb'_eq_carry {w i len : Nat} (hi : i < len)
The `[0..len)` low bits of `x + y` can be computed by truncating `x` and `y`
to `len` bits and then adding.
-/
theorem extractLsb'_add {w len : Nat} {x y : BitVec w} (hlen : len ≤ w) :
theorem extractLsb'_add {w len : Nat} {x y : BitVec w} (hlen : len ≤ w) :
(x + y).extractLsb' 0 len = x.extractLsb' 0 len + y.extractLsb' 0 len := by
ext i hi
rw [getElem_extractLsb', Nat.zero_add, getLsbD_add (by omega)]
simp [getElem_add, carry_extractLsb'_eq_carry hi, getElem_extractLsb', Nat.zero_add]
-- `setWidth` commutes with multiplication. -/
theorem setWidth_mul {w len} {x y : BitVec w} (hlen : len ≤ w) :
(x * y).setWidth len = (x.setWidth len) * (y.setWidth len) := by
apply eq_of_toNat_eq
simp only [toNat_setWidth, toNat_mul, mul_mod_mod, mod_mul_mod]
rw [Nat.mod_mod_of_dvd]
exact pow_dvd_pow_iff_le_right'.mpr hlen
/-- `extractLsb'` commutes with multiplication. -/
theorem extractLsb'_mul {w len} {x y : BitVec w} (hlen : len ≤ w) :
(x * y).extractLsb' 0 len = (x.extractLsb' 0 len) * (y.extractLsb' 0 len) := by
simp [← setWidth_eq_extractLsb' hlen, setWidth_mul hlen]
simp [← setWidth_eq_extractLsb' hlen, setWidth_mul _ _ hlen]
end BitVec

View file

@ -716,6 +716,9 @@ theorem slt_zero_eq_msb {w : Nat} {x : BitVec w} : x.slt 0#w = x.msb := by
theorem sle_iff_toInt_le {w : Nat} {b b' : BitVec w} : b.sle b' ↔ b.toInt ≤ b'.toInt :=
decide_eq_true_iff
theorem slt_iff_toInt_lt {w : Nat} {b b' : BitVec w} : b.slt b' ↔ b.toInt < b'.toInt :=
decide_eq_true_iff
/-! ### setWidth, zeroExtend and truncate -/
@[simp]
@ -1243,6 +1246,9 @@ theorem extractLsb_or {x : BitVec w} {hi lo : Nat} :
ext k hk
simp [hk, show k ≤ lo - hi by omega]
@[simp] theorem ofNat_or {x y : Nat} : BitVec.ofNat w (x ||| y) = BitVec.ofNat w x ||| BitVec.ofNat w y :=
eq_of_toNat_eq (by simp [Nat.or_mod_two_pow])
/-! ### and -/
@[simp] theorem toNat_and (x y : BitVec v) :
@ -1340,6 +1346,9 @@ theorem extractLsb_and {x : BitVec w} {hi lo : Nat} :
ext k hk
simp [hk, show k ≤ lo - hi by omega]
@[simp] theorem ofNat_and {x y : Nat} : BitVec.ofNat w (x &&& y) = BitVec.ofNat w x &&& BitVec.ofNat w y :=
eq_of_toNat_eq (by simp [Nat.and_mod_two_pow])
/-! ### xor -/
@[simp] theorem toNat_xor (x y : BitVec v) :
@ -1440,6 +1449,9 @@ theorem extractLsb_xor {x : BitVec w} {hi lo : Nat} :
ext k hk
simp [hk, show k ≤ lo - hi by omega]
@[simp] theorem ofNat_xor {x y : Nat} : BitVec.ofNat w (x ^^^ y) = BitVec.ofNat w x ^^^ BitVec.ofNat w y :=
eq_of_toNat_eq (by simp [Nat.xor_mod_two_pow])
/-! ### not -/
theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
@ -2402,6 +2414,30 @@ theorem toInt_signExtend_eq_toInt_bmod_of_le (x : BitVec w) (h : v ≤ w) :
(x.signExtend v).toInt = x.toInt.bmod (2 ^ v) := by
rw [BitVec.toInt_signExtend, Nat.min_eq_left h]
@[simp] theorem signExtend_and {x y : BitVec w} :
(x &&& y).signExtend v = (x.signExtend v) &&& (y.signExtend v) := by
refine eq_of_getElem_eq (fun i hi => ?_)
simp only [getElem_signExtend, getElem_and, msb_and]
split <;> simp
@[simp] theorem signExtend_or {x y : BitVec w} :
(x ||| y).signExtend v = (x.signExtend v) ||| (y.signExtend v) := by
refine eq_of_getElem_eq (fun i hi => ?_)
simp only [getElem_signExtend, getElem_or, msb_or]
split <;> simp
@[simp] theorem signExtend_xor {x y : BitVec w} :
(x ^^^ y).signExtend v = (x.signExtend v) ^^^ (y.signExtend v) := by
refine eq_of_getElem_eq (fun i hi => ?_)
simp only [getElem_signExtend, getElem_xor, msb_xor]
split <;> simp
@[simp] theorem signExtend_not {x : BitVec w} (h : 0 < w) :
(~~~x).signExtend v = ~~~(x.signExtend v) := by
refine eq_of_getElem_eq (fun i hi => ?_)
simp [getElem_signExtend]
split <;> simp_all
/-! ### append -/
theorem append_def (x : BitVec v) (y : BitVec w) :
@ -3463,6 +3499,11 @@ theorem neg_mul_not_eq_add_mul {x y : BitVec w} :
theorem neg_eq_neg_one_mul (b : BitVec w) : -b = -1#w * b :=
BitVec.eq_of_toInt_eq (by simp)
theorem setWidth_mul (x y : BitVec w) (h : i ≤ w) :
(x * y).setWidth i = x.setWidth i * y.setWidth i := by
have dvd : 2^i 2^w := Nat.pow_dvd_pow _ h
simp [bitvec_to_nat, h, Nat.mod_mod_of_dvd _ dvd]
/-! ### le and lt -/
@[bitvec_to_nat] theorem le_def {x y : BitVec n} :
@ -4624,6 +4665,14 @@ theorem getLsbD_intMax (w : Nat) : (intMax w).getLsbD i = decide (i + 1 < w) :=
· simp [h]
· rw [Nat.sub_add_cancel (Nat.two_pow_pos (w - 1)), Nat.two_pow_pred_mod_two_pow (by omega)]
@[simp] theorem toInt_intMax : (BitVec.intMax w).toInt = 2 ^ (w - 1) - 1 := by
refine (Nat.eq_zero_or_pos w).elim (by rintro rfl; simp [BitVec.toInt_of_zero_length]) (fun hw => ?_)
rw [BitVec.toInt, toNat_intMax, if_pos]
· rw [Int.ofNat_sub Nat.one_le_two_pow, Int.natCast_pow, Int.cast_ofNat_Int, Int.cast_ofNat_Int]
· rw [Nat.mul_sub_left_distrib, ← Nat.pow_succ', Nat.succ_eq_add_one, Nat.sub_add_cancel hw]
apply Nat.sub_lt_self (by decide)
rw [Nat.mul_one]
apply Nat.le_pow hw
/-! ### Non-overflow theorems -/

View file

@ -286,6 +286,9 @@ theorem sub_one_lt : ∀ {n : Nat}, n ≠ 0 → n - 1 < n := pred_lt
| zero => exact Nat.le_refl (n - 0)
| succ m ih => apply Nat.le_trans (pred_le (n - m)) ih
theorem sub_lt_of_lt {a b c : Nat} (h : a < c) : a - b < c :=
Nat.lt_of_le_of_lt (Nat.sub_le _ _) h
theorem sub_lt : ∀ {n m : Nat}, 0 < n → 0 < m → n - m < n
| 0, _, h1, _ => absurd h1 (Nat.lt_irrefl 0)
| _+1, 0, _, h2 => absurd h2 (Nat.lt_irrefl 0)
@ -413,6 +416,9 @@ protected theorem lt_add_right (c : Nat) (h : a < b) : a < b + c :=
theorem lt_of_add_right_lt {n m k : Nat} (h : n + k < m) : n < m :=
Nat.lt_of_le_of_lt (Nat.le_add_right ..) h
theorem lt_of_add_left_lt {n m k : Nat} (h : k + n < m) : n < m :=
lt_of_add_right_lt (Nat.add_comm _ _ ▸ h)
theorem le.dest : ∀ {n m : Nat}, n ≤ m → Exists (fun k => n + k = m)
| zero, zero, _ => ⟨0, rfl⟩
| zero, succ n, _ => ⟨succ n, Nat.add_comm 0 (succ n) ▸ rfl⟩

View file

@ -460,6 +460,14 @@ theorem bitwise_div_two_pow (of_false_false : f false false = false := by rfl) :
apply Nat.eq_of_testBit_eq
simp [testBit_bitwise of_false_false, testBit_div_two_pow]
theorem bitwise_mod_two_pow (of_false_false : f false false = false := by rfl) :
(bitwise f x y) % 2 ^ n = bitwise f (x % 2 ^ n) (y % 2 ^ n) := by
apply Nat.eq_of_testBit_eq
simp only [testBit_mod_two_pow, testBit_bitwise of_false_false]
intro i
by_cases h : i < n <;> simp only [h, decide_true, decide_false, Bool.true_and, Bool.false_and,
of_false_false]
/-! ### and -/
@[simp] theorem testBit_and (x y i : Nat) : (x &&& y).testBit i = (x.testBit i && y.testBit i) := by
@ -527,6 +535,9 @@ theorem and_div_two_pow : (a &&& b) / 2 ^ n = a / 2 ^ n &&& b / 2 ^ n :=
theorem and_div_two : (a &&& b) / 2 = a / 2 &&& b / 2 :=
and_div_two_pow (n := 1)
theorem and_mod_two_pow : (a &&& b) % 2 ^ n = (a % 2 ^ n) &&& (b % 2 ^ n) :=
bitwise_mod_two_pow
/-! ### lor -/
@[simp] theorem zero_or (x : Nat) : 0 ||| x = x := by
@ -597,6 +608,9 @@ theorem or_div_two_pow : (a ||| b) / 2 ^ n = a / 2 ^ n ||| b / 2 ^ n :=
theorem or_div_two : (a ||| b) / 2 = a / 2 ||| b / 2 :=
or_div_two_pow (n := 1)
theorem or_mod_two_pow : (a ||| b) % 2 ^ n = a % 2 ^ n ||| b % 2 ^ n :=
bitwise_mod_two_pow
/-! ### xor -/
@[simp] theorem testBit_xor (x y i : Nat) :
@ -655,6 +669,9 @@ theorem xor_div_two_pow : (a ^^^ b) / 2 ^ n = a / 2 ^ n ^^^ b / 2 ^ n :=
theorem xor_div_two : (a ^^^ b) / 2 = a / 2 ^^^ b / 2 :=
xor_div_two_pow (n := 1)
theorem xor_mod_two_pow : (a ^^^ b) % 2 ^ n = a % 2 ^ n ^^^ b % 2 ^ n :=
bitwise_mod_two_pow
/-! ### Arithmetic -/
theorem testBit_two_pow_mul_add (a : Nat) {b i : Nat} (b_lt : b < 2^i) (j : Nat) :
@ -774,6 +791,20 @@ theorem shiftRight_or_distrib {a b : Nat} : (a ||| b) >>> i = a >>> i ||| b >>>
theorem shiftRight_xor_distrib {a b : Nat} : (a ^^^ b) >>> i = a >>> i ^^^ b >>> i :=
shiftRight_bitwise_distrib
theorem mod_two_pow_shiftLeft_mod_two_pow {a b c : Nat} : ((a % 2 ^ c) <<< b) % 2 ^ c = (a <<< b) % 2 ^ c := by
apply Nat.eq_of_testBit_eq
simp only [testBit_mod_two_pow, testBit_shiftLeft, ge_iff_le]
intro i
by_cases hic : i < c
· simp [(by omega : i - b < c)]
· simp [*]
theorem le_shiftLeft {a b : Nat} : a ≤ a <<< b :=
shiftLeft_eq _ _ ▸ Nat.le_mul_of_pos_right _ (Nat.two_pow_pos _)
theorem lt_of_shiftLeft_lt {a b c : Nat} (h : a <<< b < c) : a < c :=
Nat.lt_of_le_of_lt le_shiftLeft h
/-! ### le -/
theorem le_of_testBit {n m : Nat} (h : ∀ i, n.testBit i = true → m.testBit i = true) : n ≤ m := by

View file

@ -320,6 +320,9 @@ theorem mod_le (x y : Nat) : x % y ≤ x := by
| Or.inl h₂ => rw [h₂, Nat.mod_zero x]; apply Nat.le_refl
| Or.inr h₂ => exact Nat.le_trans (Nat.le_of_lt (mod_lt _ h₂)) h₁
theorem mod_lt_of_lt {a b c : Nat} (h : a < c) : a % b < c :=
Nat.lt_of_le_of_lt (Nat.mod_le _ _) h
@[simp] theorem zero_mod (b : Nat) : 0 % b = 0 := by
rw [mod_eq]
have : ¬ (0 < b ∧ b = 0) := by

View file

@ -778,6 +778,11 @@ theorem two_pow_pred_mul_two (h : 0 < w) :
2 ^ (w - 1) * 2 = 2 ^ w := by
simp [← Nat.pow_succ, Nat.sub_add_cancel h]
theorem le_pow {a b : Nat} (h : 0 < b) : a ≤ a ^ b := by
refine (eq_zero_or_pos a).elim (by rintro rfl; simp) (fun ha => ?_)
rw [(show b = b - 1 + 1 by omega), Nat.pow_succ]
exact Nat.le_mul_of_pos_left _ (Nat.pow_pos ha)
/-! ### log2 -/
@[simp]

View file

@ -74,5 +74,8 @@ theorem numBits_le : numBits ≤ 64 := by
@[simp] theorem System.Platform.numBits_dvd : System.Platform.numBits 64 :=
numBits_eq.elim (fun h => ⟨2, h ▸ rfl⟩) (fun h => ⟨1, by simp [h]⟩)
instance : NeZero System.Platform.numBits where
out := Nat.ne_zero_of_lt System.Platform.numBits_pos
end Platform
end System