diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index deee66a751..eb2e202331 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -89,6 +89,9 @@ theorem eq_of_toFin_eq : ∀ {x y : BitVec w}, x.toFin = y.toFin → x = y @[simp] theorem toNat_ofBool (b : Bool) : (ofBool b).toNat = b.toNat := by cases b <;> rfl +@[simp] theorem msb_ofBool (b : Bool) : (ofBool b).msb = b := by + cases b <;> simp [BitVec.msb] + theorem ofNat_one (n : Nat) : BitVec.ofNat 1 n = BitVec.ofBool (n % 2 = 1) := by rcases (Nat.mod_two_eq_zero_or_one n) with h | h <;> simp [h, BitVec.ofNat, Fin.ofNat'] @@ -116,6 +119,8 @@ theorem getLsb_ofNat (n : Nat) (x : Nat) (i : Nat) : @[simp] theorem getLsb_zero : (0#w).getLsb i = false := by simp [getLsb] +@[simp] theorem getMsb_zero : (0#w).getMsb i = false := by simp [getMsb] + @[simp] theorem toNat_mod_cancel (x : BitVec n) : x.toNat % (2^n) = x.toNat := Nat.mod_eq_of_lt x.isLt @@ -241,6 +246,12 @@ theorem toInt_ofNat {n : Nat} (x : Nat) : else simp [n_le_i, toNat_ofNat] +theorem zeroExtend'_eq {x : BitVec w} (h : w ≤ v) : x.zeroExtend' h = x.zeroExtend v := by + apply eq_of_toNat_eq + rw [toNat_zeroExtend, toNat_zeroExtend'] + rw [Nat.mod_eq_of_lt] + exact Nat.lt_of_lt_of_le x.isLt (Nat.pow_le_pow_right (Nat.zero_lt_two) h) + @[simp, bv_toNat] theorem toNat_truncate (x : BitVec n) : (truncate i x).toNat = x.toNat % 2^i := toNat_zeroExtend i x @@ -285,6 +296,18 @@ theorem nat_eq_toNat (x : BitVec w) (y : Nat) getLsb (zeroExtend m x) i = (decide (i < m) && getLsb x i) := by simp [getLsb, toNat_zeroExtend, Nat.testBit_mod_two_pow] +@[simp] theorem getMsb_zeroExtend_add {x : BitVec w} (h : k ≤ i) : + (x.zeroExtend (w + k)).getMsb i = x.getMsb (i - k) := by + by_cases h : w = 0 + · subst h; simp + simp only [getMsb, getLsb_zeroExtend] + by_cases h₁ : i < w + k <;> by_cases h₂ : i - k < w <;> by_cases h₃ : w + k - 1 - i < w + k + <;> simp [h₁, h₂, h₃] + · congr 1 + omega + all_goals (first | apply getLsb_ge | apply Eq.symm; apply getLsb_ge) + <;> omega + @[simp] theorem getLsb_truncate (m : Nat) (x : BitVec n) (i : Nat) : getLsb (truncate m x) i = (decide (i < m) && getLsb x i) := getLsb_zeroExtend m x i @@ -301,11 +324,18 @@ theorem nat_eq_toNat (x : BitVec w) (y : Nat) (x.truncate l).truncate k = x.truncate k := zeroExtend_zeroExtend_of_le x h +@[simp] theorem truncate_cast {h : w = v} : (cast h x).truncate k = x.truncate k := by + apply eq_of_getLsb_eq + simp + theorem msb_zeroExtend (x : BitVec w) : (x.zeroExtend v).msb = (decide (0 < v) && x.getLsb (v - 1)) := by rw [msb_eq_getLsb_last] simp only [getLsb_zeroExtend] cases getLsb x (v - 1) <;> simp; omega +theorem msb_zeroExtend' (x : BitVec w) (h : w ≤ v) : (x.zeroExtend' h).msb = (decide (0 < v) && x.getLsb (v - 1)) := by + rw [zeroExtend'_eq, msb_zeroExtend] + /-! ## extractLsb -/ @[simp] @@ -353,6 +383,13 @@ protected theorem extractLsb_ofNat (x n : Nat) (hi lo : Nat) : rw [← testBit_toNat, getLsb, getLsb] simp +@[simp] theorem getMsb_or {x y : BitVec w} : (x ||| y).getMsb i = (x.getMsb i || y.getMsb i) := by + simp only [getMsb] + by_cases h : i < w <;> simp [h] + +@[simp] theorem msb_or {x y : BitVec w} : (x ||| y).msb = (x.msb || y.msb) := by + simp [BitVec.msb] + /-! ### and -/ @[simp] theorem toNat_and (x y : BitVec v) : @@ -367,6 +404,13 @@ protected theorem extractLsb_ofNat (x n : Nat) (hi lo : Nat) : rw [← testBit_toNat, getLsb, getLsb] simp +@[simp] theorem getMsb_and {x y : BitVec w} : (x &&& y).getMsb i = (x.getMsb i && y.getMsb i) := by + simp only [getMsb] + by_cases h : i < w <;> simp [h] + +@[simp] theorem msb_and {x y : BitVec w} : (x &&& y).msb = (x.msb && y.msb) := by + simp [BitVec.msb] + /-! ### xor -/ @[simp] theorem toNat_xor (x y : BitVec v) : @@ -431,6 +475,19 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl cases h₁ : decide (i < m) <;> cases h₂ : decide (n ≤ i) <;> cases h₃ : decide (i < n) all_goals { simp_all <;> omega } +@[simp] theorem getMsb_shiftLeft (x : BitVec w) (i) : + (x <<< i).getMsb k = x.getMsb (k + i) := by + simp only [getMsb, getLsb_shiftLeft] + by_cases h : w = 0 + · subst h; simp + have t : w - 1 - k < w := by omega + simp only [t] + simp only [decide_True, Nat.sub_sub, Bool.true_and, Nat.add_assoc] + by_cases h₁ : k < w <;> by_cases h₂ : w - (1 + k) < i <;> by_cases h₃ : k + i < w + <;> simp [h₁, h₂, h₃] + <;> (first | apply getLsb_ge | apply Eq.symm; apply getLsb_ge) + <;> omega + theorem shiftLeftZeroExtend_eq {x : BitVec w} : shiftLeftZeroExtend x n = zeroExtend (w+n) x <<< n := by apply eq_of_toNat_eq @@ -450,6 +507,10 @@ theorem shiftLeftZeroExtend_eq {x : BitVec w} : <;> simp_all <;> (rw [getLsb_ge]; omega) +@[simp] theorem msb_shiftLeftZeroExtend (x : BitVec w) (i : Nat) : + (shiftLeftZeroExtend x i).msb = x.msb := by + simp [shiftLeftZeroExtend_eq, BitVec.msb] + /-! ### ushiftRight -/ @[simp, bv_toNat] theorem toNat_ushiftRight (x : BitVec n) (i : Nat) : @@ -475,6 +536,34 @@ theorem append_def (x : BitVec v) (y : BitVec w) : · simp [h] · simp [h]; simp_all +theorem msb_append {x : BitVec w} {y : BitVec v} : + (x ++ y).msb = bif (w == 0) then (y.msb) else (x.msb) := by + rw [← append_eq, append] + simp [msb_zeroExtend'] + by_cases h : w = 0 + · subst h + simp [BitVec.msb, getMsb] + · rw [cond_eq_if] + have q : 0 < w + v := by omega + have t : y.getLsb (w + v - 1) = false := getLsb_ge _ _ (by omega) + simp [h, q, t, BitVec.msb, getMsb] + +@[simp] theorem truncate_append {x : BitVec w} {y : BitVec v} : + (x ++ y).truncate k = if h : k ≤ v then y.truncate k else (x.truncate (k - v) ++ y).cast (by omega) := by + apply eq_of_getLsb_eq + intro i + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, getLsb_append, Bool.true_and] + split + · have t : i < v := by omega + simp [t] + · by_cases t : i < v + · simp [t] + · have t' : i - v < k - v := by omega + simp [t, t'] + +@[simp] theorem truncate_cons {x : BitVec w} : (cons a x).truncate w = x := by + simp [cons] + /-! ### rev -/ theorem getLsb_rev (x : BitVec w) (i : Fin w) : @@ -511,6 +600,9 @@ theorem getMsb_rev (x : BitVec w) (i : Fin w) : have p2 : i - n ≠ 0 := by omega simp [p1, p2, Nat.testBit_bool_to_nat] +@[simp] theorem msb_cons : (cons a x).msb = a := by + simp [cons, msb_cast, msb_append] + theorem truncate_succ (x : BitVec w) : truncate (i+1) x = cons (getLsb x i) (truncate i x) := by apply eq_of_getLsb_eq