From 92439acee533acf806ae6b9aaed41b3e8a6319bb Mon Sep 17 00:00:00 2001 From: Markus Himmel Date: Mon, 24 Mar 2025 16:04:53 +0100 Subject: [PATCH] feat: supporting `Nat` and `BitVec` material for finite types (#7598) This PR adds miscellaneous results about `Nat` and `BitVec` that will be required for `IntX` theory (#7592). --- src/Init/Data/BitVec/Bitblast.lean | 14 ++------ src/Init/Data/BitVec/Lemmas.lean | 49 +++++++++++++++++++++++++++ src/Init/Data/Nat/Basic.lean | 6 ++++ src/Init/Data/Nat/Bitwise/Lemmas.lean | 31 +++++++++++++++++ src/Init/Data/Nat/Div/Basic.lean | 3 ++ src/Init/Data/Nat/Lemmas.lean | 5 +++ src/Init/System/Platform.lean | 3 ++ 7 files changed, 100 insertions(+), 11 deletions(-) diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index aadc5e03f7..d77e3cb6d5 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -1400,7 +1400,7 @@ theorem not_sub_eq_not_add {x y : BitVec w} : ~~~ (x - y) = ~~~ x + y := by /-- The value of `(carry i x y false)` can be computed by truncating `x` and `y` to `len` bits where `len ≥ i`. -/ theorem carry_extractLsb'_eq_carry {w i len : Nat} (hi : i < len) - {x y : BitVec w} {b : Bool}: + {x y : BitVec w} {b : Bool}: (carry i (extractLsb' 0 len x) (extractLsb' 0 len y) b) = (carry i x y b) := by simp only [carry, extractLsb'_toNat, shiftRight_zero, toNat_false, Nat.add_zero, ge_iff_le, @@ -1414,23 +1414,15 @@ theorem carry_extractLsb'_eq_carry {w i len : Nat} (hi : i < len) The `[0..len)` low bits of `x + y` can be computed by truncating `x` and `y` to `len` bits and then adding. -/ -theorem extractLsb'_add {w len : Nat} {x y : BitVec w} (hlen : len ≤ w) : +theorem extractLsb'_add {w len : Nat} {x y : BitVec w} (hlen : len ≤ w) : (x + y).extractLsb' 0 len = x.extractLsb' 0 len + y.extractLsb' 0 len := by ext i hi rw [getElem_extractLsb', Nat.zero_add, getLsbD_add (by omega)] simp [getElem_add, carry_extractLsb'_eq_carry hi, getElem_extractLsb', Nat.zero_add] --- `setWidth` commutes with multiplication. -/ -theorem setWidth_mul {w len} {x y : BitVec w} (hlen : len ≤ w) : - (x * y).setWidth len = (x.setWidth len) * (y.setWidth len) := by - apply eq_of_toNat_eq - simp only [toNat_setWidth, toNat_mul, mul_mod_mod, mod_mul_mod] - rw [Nat.mod_mod_of_dvd] - exact pow_dvd_pow_iff_le_right'.mpr hlen - /-- `extractLsb'` commutes with multiplication. -/ theorem extractLsb'_mul {w len} {x y : BitVec w} (hlen : len ≤ w) : (x * y).extractLsb' 0 len = (x.extractLsb' 0 len) * (y.extractLsb' 0 len) := by - simp [← setWidth_eq_extractLsb' hlen, setWidth_mul hlen] + simp [← setWidth_eq_extractLsb' hlen, setWidth_mul _ _ hlen] end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index a4eef825af..a2d127b173 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -716,6 +716,9 @@ theorem slt_zero_eq_msb {w : Nat} {x : BitVec w} : x.slt 0#w = x.msb := by theorem sle_iff_toInt_le {w : Nat} {b b' : BitVec w} : b.sle b' ↔ b.toInt ≤ b'.toInt := decide_eq_true_iff +theorem slt_iff_toInt_lt {w : Nat} {b b' : BitVec w} : b.slt b' ↔ b.toInt < b'.toInt := + decide_eq_true_iff + /-! ### setWidth, zeroExtend and truncate -/ @[simp] @@ -1243,6 +1246,9 @@ theorem extractLsb_or {x : BitVec w} {hi lo : Nat} : ext k hk simp [hk, show k ≤ lo - hi by omega] +@[simp] theorem ofNat_or {x y : Nat} : BitVec.ofNat w (x ||| y) = BitVec.ofNat w x ||| BitVec.ofNat w y := + eq_of_toNat_eq (by simp [Nat.or_mod_two_pow]) + /-! ### and -/ @[simp] theorem toNat_and (x y : BitVec v) : @@ -1340,6 +1346,9 @@ theorem extractLsb_and {x : BitVec w} {hi lo : Nat} : ext k hk simp [hk, show k ≤ lo - hi by omega] +@[simp] theorem ofNat_and {x y : Nat} : BitVec.ofNat w (x &&& y) = BitVec.ofNat w x &&& BitVec.ofNat w y := + eq_of_toNat_eq (by simp [Nat.and_mod_two_pow]) + /-! ### xor -/ @[simp] theorem toNat_xor (x y : BitVec v) : @@ -1440,6 +1449,9 @@ theorem extractLsb_xor {x : BitVec w} {hi lo : Nat} : ext k hk simp [hk, show k ≤ lo - hi by omega] +@[simp] theorem ofNat_xor {x y : Nat} : BitVec.ofNat w (x ^^^ y) = BitVec.ofNat w x ^^^ BitVec.ofNat w y := + eq_of_toNat_eq (by simp [Nat.xor_mod_two_pow]) + /-! ### not -/ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl @@ -2402,6 +2414,30 @@ theorem toInt_signExtend_eq_toInt_bmod_of_le (x : BitVec w) (h : v ≤ w) : (x.signExtend v).toInt = x.toInt.bmod (2 ^ v) := by rw [BitVec.toInt_signExtend, Nat.min_eq_left h] +@[simp] theorem signExtend_and {x y : BitVec w} : + (x &&& y).signExtend v = (x.signExtend v) &&& (y.signExtend v) := by + refine eq_of_getElem_eq (fun i hi => ?_) + simp only [getElem_signExtend, getElem_and, msb_and] + split <;> simp + +@[simp] theorem signExtend_or {x y : BitVec w} : + (x ||| y).signExtend v = (x.signExtend v) ||| (y.signExtend v) := by + refine eq_of_getElem_eq (fun i hi => ?_) + simp only [getElem_signExtend, getElem_or, msb_or] + split <;> simp + +@[simp] theorem signExtend_xor {x y : BitVec w} : + (x ^^^ y).signExtend v = (x.signExtend v) ^^^ (y.signExtend v) := by + refine eq_of_getElem_eq (fun i hi => ?_) + simp only [getElem_signExtend, getElem_xor, msb_xor] + split <;> simp + +@[simp] theorem signExtend_not {x : BitVec w} (h : 0 < w) : + (~~~x).signExtend v = ~~~(x.signExtend v) := by + refine eq_of_getElem_eq (fun i hi => ?_) + simp [getElem_signExtend] + split <;> simp_all + /-! ### append -/ theorem append_def (x : BitVec v) (y : BitVec w) : @@ -3463,6 +3499,11 @@ theorem neg_mul_not_eq_add_mul {x y : BitVec w} : theorem neg_eq_neg_one_mul (b : BitVec w) : -b = -1#w * b := BitVec.eq_of_toInt_eq (by simp) +theorem setWidth_mul (x y : BitVec w) (h : i ≤ w) : + (x * y).setWidth i = x.setWidth i * y.setWidth i := by + have dvd : 2^i ∣ 2^w := Nat.pow_dvd_pow _ h + simp [bitvec_to_nat, h, Nat.mod_mod_of_dvd _ dvd] + /-! ### le and lt -/ @[bitvec_to_nat] theorem le_def {x y : BitVec n} : @@ -4624,6 +4665,14 @@ theorem getLsbD_intMax (w : Nat) : (intMax w).getLsbD i = decide (i + 1 < w) := · simp [h] · rw [Nat.sub_add_cancel (Nat.two_pow_pos (w - 1)), Nat.two_pow_pred_mod_two_pow (by omega)] +@[simp] theorem toInt_intMax : (BitVec.intMax w).toInt = 2 ^ (w - 1) - 1 := by + refine (Nat.eq_zero_or_pos w).elim (by rintro rfl; simp [BitVec.toInt_of_zero_length]) (fun hw => ?_) + rw [BitVec.toInt, toNat_intMax, if_pos] + · rw [Int.ofNat_sub Nat.one_le_two_pow, Int.natCast_pow, Int.cast_ofNat_Int, Int.cast_ofNat_Int] + · rw [Nat.mul_sub_left_distrib, ← Nat.pow_succ', Nat.succ_eq_add_one, Nat.sub_add_cancel hw] + apply Nat.sub_lt_self (by decide) + rw [Nat.mul_one] + apply Nat.le_pow hw /-! ### Non-overflow theorems -/ diff --git a/src/Init/Data/Nat/Basic.lean b/src/Init/Data/Nat/Basic.lean index 47c0f6528c..d06bce9fee 100644 --- a/src/Init/Data/Nat/Basic.lean +++ b/src/Init/Data/Nat/Basic.lean @@ -286,6 +286,9 @@ theorem sub_one_lt : ∀ {n : Nat}, n ≠ 0 → n - 1 < n := pred_lt | zero => exact Nat.le_refl (n - 0) | succ m ih => apply Nat.le_trans (pred_le (n - m)) ih +theorem sub_lt_of_lt {a b c : Nat} (h : a < c) : a - b < c := + Nat.lt_of_le_of_lt (Nat.sub_le _ _) h + theorem sub_lt : ∀ {n m : Nat}, 0 < n → 0 < m → n - m < n | 0, _, h1, _ => absurd h1 (Nat.lt_irrefl 0) | _+1, 0, _, h2 => absurd h2 (Nat.lt_irrefl 0) @@ -413,6 +416,9 @@ protected theorem lt_add_right (c : Nat) (h : a < b) : a < b + c := theorem lt_of_add_right_lt {n m k : Nat} (h : n + k < m) : n < m := Nat.lt_of_le_of_lt (Nat.le_add_right ..) h +theorem lt_of_add_left_lt {n m k : Nat} (h : k + n < m) : n < m := + lt_of_add_right_lt (Nat.add_comm _ _ ▸ h) + theorem le.dest : ∀ {n m : Nat}, n ≤ m → Exists (fun k => n + k = m) | zero, zero, _ => ⟨0, rfl⟩ | zero, succ n, _ => ⟨succ n, Nat.add_comm 0 (succ n) ▸ rfl⟩ diff --git a/src/Init/Data/Nat/Bitwise/Lemmas.lean b/src/Init/Data/Nat/Bitwise/Lemmas.lean index 852939e129..8e46e450ec 100644 --- a/src/Init/Data/Nat/Bitwise/Lemmas.lean +++ b/src/Init/Data/Nat/Bitwise/Lemmas.lean @@ -460,6 +460,14 @@ theorem bitwise_div_two_pow (of_false_false : f false false = false := by rfl) : apply Nat.eq_of_testBit_eq simp [testBit_bitwise of_false_false, testBit_div_two_pow] +theorem bitwise_mod_two_pow (of_false_false : f false false = false := by rfl) : + (bitwise f x y) % 2 ^ n = bitwise f (x % 2 ^ n) (y % 2 ^ n) := by + apply Nat.eq_of_testBit_eq + simp only [testBit_mod_two_pow, testBit_bitwise of_false_false] + intro i + by_cases h : i < n <;> simp only [h, decide_true, decide_false, Bool.true_and, Bool.false_and, + of_false_false] + /-! ### and -/ @[simp] theorem testBit_and (x y i : Nat) : (x &&& y).testBit i = (x.testBit i && y.testBit i) := by @@ -527,6 +535,9 @@ theorem and_div_two_pow : (a &&& b) / 2 ^ n = a / 2 ^ n &&& b / 2 ^ n := theorem and_div_two : (a &&& b) / 2 = a / 2 &&& b / 2 := and_div_two_pow (n := 1) +theorem and_mod_two_pow : (a &&& b) % 2 ^ n = (a % 2 ^ n) &&& (b % 2 ^ n) := + bitwise_mod_two_pow + /-! ### lor -/ @[simp] theorem zero_or (x : Nat) : 0 ||| x = x := by @@ -597,6 +608,9 @@ theorem or_div_two_pow : (a ||| b) / 2 ^ n = a / 2 ^ n ||| b / 2 ^ n := theorem or_div_two : (a ||| b) / 2 = a / 2 ||| b / 2 := or_div_two_pow (n := 1) +theorem or_mod_two_pow : (a ||| b) % 2 ^ n = a % 2 ^ n ||| b % 2 ^ n := + bitwise_mod_two_pow + /-! ### xor -/ @[simp] theorem testBit_xor (x y i : Nat) : @@ -655,6 +669,9 @@ theorem xor_div_two_pow : (a ^^^ b) / 2 ^ n = a / 2 ^ n ^^^ b / 2 ^ n := theorem xor_div_two : (a ^^^ b) / 2 = a / 2 ^^^ b / 2 := xor_div_two_pow (n := 1) +theorem xor_mod_two_pow : (a ^^^ b) % 2 ^ n = a % 2 ^ n ^^^ b % 2 ^ n := + bitwise_mod_two_pow + /-! ### Arithmetic -/ theorem testBit_two_pow_mul_add (a : Nat) {b i : Nat} (b_lt : b < 2^i) (j : Nat) : @@ -774,6 +791,20 @@ theorem shiftRight_or_distrib {a b : Nat} : (a ||| b) >>> i = a >>> i ||| b >>> theorem shiftRight_xor_distrib {a b : Nat} : (a ^^^ b) >>> i = a >>> i ^^^ b >>> i := shiftRight_bitwise_distrib +theorem mod_two_pow_shiftLeft_mod_two_pow {a b c : Nat} : ((a % 2 ^ c) <<< b) % 2 ^ c = (a <<< b) % 2 ^ c := by + apply Nat.eq_of_testBit_eq + simp only [testBit_mod_two_pow, testBit_shiftLeft, ge_iff_le] + intro i + by_cases hic : i < c + · simp [(by omega : i - b < c)] + · simp [*] + +theorem le_shiftLeft {a b : Nat} : a ≤ a <<< b := + shiftLeft_eq _ _ ▸ Nat.le_mul_of_pos_right _ (Nat.two_pow_pos _) + +theorem lt_of_shiftLeft_lt {a b c : Nat} (h : a <<< b < c) : a < c := + Nat.lt_of_le_of_lt le_shiftLeft h + /-! ### le -/ theorem le_of_testBit {n m : Nat} (h : ∀ i, n.testBit i = true → m.testBit i = true) : n ≤ m := by diff --git a/src/Init/Data/Nat/Div/Basic.lean b/src/Init/Data/Nat/Div/Basic.lean index 311039aac1..5b76ca844b 100644 --- a/src/Init/Data/Nat/Div/Basic.lean +++ b/src/Init/Data/Nat/Div/Basic.lean @@ -320,6 +320,9 @@ theorem mod_le (x y : Nat) : x % y ≤ x := by | Or.inl h₂ => rw [h₂, Nat.mod_zero x]; apply Nat.le_refl | Or.inr h₂ => exact Nat.le_trans (Nat.le_of_lt (mod_lt _ h₂)) h₁ +theorem mod_lt_of_lt {a b c : Nat} (h : a < c) : a % b < c := + Nat.lt_of_le_of_lt (Nat.mod_le _ _) h + @[simp] theorem zero_mod (b : Nat) : 0 % b = 0 := by rw [mod_eq] have : ¬ (0 < b ∧ b = 0) := by diff --git a/src/Init/Data/Nat/Lemmas.lean b/src/Init/Data/Nat/Lemmas.lean index c1654c2faf..ddb10fabb6 100644 --- a/src/Init/Data/Nat/Lemmas.lean +++ b/src/Init/Data/Nat/Lemmas.lean @@ -778,6 +778,11 @@ theorem two_pow_pred_mul_two (h : 0 < w) : 2 ^ (w - 1) * 2 = 2 ^ w := by simp [← Nat.pow_succ, Nat.sub_add_cancel h] +theorem le_pow {a b : Nat} (h : 0 < b) : a ≤ a ^ b := by + refine (eq_zero_or_pos a).elim (by rintro rfl; simp) (fun ha => ?_) + rw [(show b = b - 1 + 1 by omega), Nat.pow_succ] + exact Nat.le_mul_of_pos_left _ (Nat.pow_pos ha) + /-! ### log2 -/ @[simp] diff --git a/src/Init/System/Platform.lean b/src/Init/System/Platform.lean index ca2f575c99..ac2640aa6e 100644 --- a/src/Init/System/Platform.lean +++ b/src/Init/System/Platform.lean @@ -74,5 +74,8 @@ theorem numBits_le : numBits ≤ 64 := by @[simp] theorem System.Platform.numBits_dvd : System.Platform.numBits ∣ 64 := numBits_eq.elim (fun h => ⟨2, h ▸ rfl⟩) (fun h => ⟨1, by simp [h]⟩) +instance : NeZero System.Platform.numBits where + out := Nat.ne_zero_of_lt System.Platform.numBits_pos + end Platform end System