diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 308b4cf60a..0f5ecc441e 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -874,4 +874,7 @@ def clzAuxRec {w : Nat} (x : BitVec w) (n : Nat) : BitVec w := /-- Count the number of leading zeros. -/ def clz (x : BitVec w) : BitVec w := clzAuxRec x (w - 1) +/-- Count the number of trailing zeros. -/ +def ctz (x : BitVec w) : BitVec w := (x.reverse).clz + end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index d0b61a84ec..11898eea0a 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -5779,6 +5779,25 @@ theorem msb_replicate {n w : Nat} {x : BitVec w} : simp only [BitVec.msb, getMsbD_replicate, Nat.zero_mod] cases n <;> cases w <;> simp +@[simp] +theorem reverse_eq_zero_iff {x : BitVec w} : + x.reverse = 0#w ↔ x = 0#w := by + constructor + · intro hrev + ext i hi + rw [← getLsbD_eq_getElem, getLsbD_eq_getMsbD, ← getLsbD_reverse] + simp [hrev] + · intro hzero + ext i hi + rw [← getLsbD_eq_getElem, getLsbD_eq_getMsbD, getMsbD_reverse] + simp [hi, hzero] + +@[simp] +theorem reverse_reverse_eq {x : BitVec w} : + x.reverse.reverse = x := by + ext k hk + rw [getElem_reverse, getMsbD_reverse, getLsbD_eq_getElem] + /-! ### Inequalities (le / lt) -/ theorem ule_eq_not_ult (x y : BitVec w) : x.ule y = !y.ult x := by @@ -6182,6 +6201,71 @@ theorem toNat_lt_two_pow_sub_clz {x : BitVec w} : · simp [show w + 1 ≤ i by omega] · simp; omega +theorem clz_eq_reverse_ctz {x : BitVec w} : + x.clz = (x.reverse).ctz := by + simp [ctz] + +/-! ### Count trailing zeros -/ + +theorem ctz_eq_reverse_clz {x : BitVec w} : + x.ctz = (x.reverse).clz := by + simp [ctz] + +/-- The number of trailing zeroes is strictly less than the bitwidth iff the bitvector is nonzero. -/ +@[simp] +theorem ctz_lt_iff_ne_zero {x : BitVec w} : + ctz x < w ↔ x ≠ 0#w := by + simp only [ctz_eq_reverse_clz, natCast_eq_ofNat, ne_eq] + rw [show BitVec.ofNat w w = w by simp, ← reverse_eq_zero_iff (x := x)] + apply clz_lt_iff_ne_zero (x := x.reverse) + +/-- If a bitvec is different than zero the bits at indexes lower than `ctz x` are false. -/ +theorem getLsbD_false_of_lt_ctz {x : BitVec w} (hi : i < x.ctz.toNat) : + x.getLsbD i = false := by + rw [getLsbD_eq_getMsbD, ← getLsbD_reverse] + have hiff := ctz_lt_iff_ne_zero (x := x) + by_cases hzero : x = 0#w + · simp [hzero, getLsbD_reverse] + · simp only [ctz_eq_reverse_clz, natCast_eq_ofNat, ne_eq, hzero, not_false_eq_true, + iff_true] at hiff + simp only [ctz] at hi + have hi' : i < w := by simp [BitVec.lt_def] at hiff; omega + simp only [hi', decide_true, Bool.true_and] + have : (x.reverse.clzAuxRec (w - 1)).toNat ≤ w := by + rw [show ((x.reverse.clzAuxRec (w - 1)).toNat ≤ w) = + ((x.reverse.clzAuxRec (w - 1)).toNat ≤ (BitVec.ofNat w w).toNat) by simp, ← le_def] + apply clzAuxRec_le (x := x.reverse) (n := w - 1) + let j := (x.reverse.clzAuxRec (w - 1)).toNat - 1 - i + rw [show w - 1 - i = w - (x.reverse.clzAuxRec (w - 1)).toNat + j by + subst j + rw [Nat.sub_sub (n := (x.reverse.clzAuxRec (w - 1)).toNat), + ← Nat.add_sub_assoc (by exact Nat.one_add_le_iff.mpr hi)] + omega] + have hfalse : ∀ (i : Nat), w - 1 < i → x.reverse.getLsbD i = false := by + intros i hj + simp [show w ≤ i by omega] + exact getLsbD_false_of_clzAuxRec (x := x.reverse) (n := w - 1) hfalse (j := j) + +/-- If a bitvec is different than zero, the bit at index `ctz x`, i.e., the first bit after the + trailing zeros, is true. -/ +theorem getLsbD_true_ctz_of_ne_zero {x : BitVec w} (hx : x ≠ 0#w) : + x.getLsbD (ctz x).toNat = true := by + simp only [ctz_eq_reverse_clz, clz] + rw [getLsbD_eq_getMsbD, ← getLsbD_reverse] + have := ctz_lt_iff_ne_zero (x := x) + simp only [ctz_eq_reverse_clz, clz, natCast_eq_ofNat, lt_def, toNat_ofNat, Nat.mod_two_pow_self, + ne_eq] at this + simp only [this, hx, not_false_eq_true, decide_true, Bool.true_and] + have hnotrev : ¬x.reverse = 0#w := by simp [reverse_eq_zero_iff, hx] + apply getLsbD_true_of_eq_clzAuxRec_of_ne_zero (x := x.reverse) (n := w - 1) hnotrev + intro i hi + simp [show w ≤ i by omega] + +/-- A nonzero bitvector is lower-bounded by its trailing zeroes. -/ +theorem two_pow_ctz_le_toNat_of_ne_zero {x : BitVec w} (hx : x ≠ 0#w) : + 2 ^ (ctz x).toNat ≤ x.toNat := by + have hclz := getLsbD_true_ctz_of_ne_zero (x := x) hx + exact Nat.ge_two_pow_of_testBit hclz /-! ### Deprecations -/ diff --git a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean index d1938b204d..f3953a21ad 100644 --- a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean +++ b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean @@ -358,6 +358,7 @@ attribute [bv_normalize] BitVec.smulOverflow_eq attribute [bv_normalize] BitVec.usubOverflow_eq attribute [bv_normalize] BitVec.ssubOverflow_eq attribute [bv_normalize] BitVec.sdivOverflow_eq +attribute [bv_normalize] BitVec.ctz attribute [bv_normalize] BitVec.append_zero_add_zero_append diff --git a/tests/lean/run/bv_decide_rewriter.lean b/tests/lean/run/bv_decide_rewriter.lean index 8f4022647f..081a6f2bff 100644 --- a/tests/lean/run/bv_decide_rewriter.lean +++ b/tests/lean/run/bv_decide_rewriter.lean @@ -674,6 +674,11 @@ example {x : BitVec 8} (h : ¬ x = 0#8) : (x >>> 1).clz = x.clz + 1 := by bv_dec example {x y : BitVec 8} : x.clz < y.clz → y < x := by bv_decide example {x : BitVec 8} : x.clz ≤ 8 := by bv_decide +-- CTZ +example {x : BitVec 8} (h : x = 0#8) : x.ctz = x.clz := by bv_decide +example {x : BitVec 8} (h : ¬ x = 0#8) : (x <<< 1).ctz = x.ctz + 1 := by bv_decide +example {x : BitVec 8} : x.ctz ≤ 8 := by bv_decide + section namespace NormalizeMul