diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 5c16ca74c4..1c208e42a6 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -536,6 +536,15 @@ def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s) instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩ instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩ +/-- +Arithmetic right shift for bit vectors. The high bits are filled with the +most-significant bit. +As a numeric operation, this is equivalent to `a.toInt >>> s.toNat`. + +SMT-Lib name: `bvashr`. +-/ +def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat + /-- Auxiliary function for `rotateLeft`, which does not take into account the case where the rotation amount is greater than the bitvector width. -/ def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w := diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 53d8bec42a..25f93feabb 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -429,6 +429,67 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) : · simp [of_length_zero] · simp [shiftLeftRec_eq] +/- ### Arithmetic shift right (sshiftRight) recurrence -/ + +/-- +`sshiftRightRec x y n` shifts `x` arithmetically/signed to the right by the first `n` bits of `y`. +The theorem `sshiftRight_eq_sshiftRightRec` proves the equivalence of `(x.sshiftRight y)` and `sshiftRightRec`. +Together with equations `sshiftRightRec_zero`, `sshiftRightRec_succ`, +this allows us to unfold `sshiftRight` into a circuit for bitblasting. +-/ +def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x.sshiftRight' shiftAmt + | n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt + +@[simp] +theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) : + sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by + simp only [sshiftRightRec, twoPow_zero] + +@[simp] +theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : + sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by + simp [sshiftRightRec] + +/-- +If `y &&& z = 0`, `x.sshiftRight (y ||| z) = (x.sshiftRight y).sshiftRight z`. +This follows as `y &&& z = 0` implies `y ||| z = y + z`, +and thus `x.sshiftRight (y ||| z) = x.sshiftRight (y + z) = (x.sshiftRight y).sshiftRight z`. +-/ +theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) : + x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by + simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h, + toNat_add_of_and_eq_zero h, sshiftRight_add] + +theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : + sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by + induction n generalizing x y + case zero => + ext i + simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, truncate_one] + case succ n ih => + simp only [sshiftRightRec_succ_eq, and_twoPow, ih] + by_cases h : y.getLsb (n + 1) + · rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h, + sshiftRight'_or_of_and_eq_zero (by simp), h] + simp + · rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1) + (by simp [h])] + simp [h] + +/-- +Show that `x.sshiftRight y` can be written in terms of `sshiftRightRec`. +This can be unfolded in terms of `sshiftRightRec_zero_eq`, `sshiftRightRec_succ_eq` for bitblasting. +-/ +theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) : + (x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [sshiftRightRec_eq] + /- ### Logical shift right (ushiftRight) recurrence for bitblasting -/ /-- diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index da1ca96bc9..6b36f912b9 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -786,7 +786,7 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) : · rw [Nat.shiftRight_eq_div_pow] apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega) -theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : +@[simp] theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : getLsb (x.sshiftRight s) i = (!decide (w ≤ i) && if s + i < w then x.getLsb (s + i) else x.msb) := by rcases hmsb : x.msb with rfl | rfl @@ -807,6 +807,41 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : Nat.not_lt, decide_eq_true_eq] omega +/-- The msb after arithmetic shifting right equals the original msb. -/ +theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} : + (x.sshiftRight n).msb = x.msb := by + rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last] + by_cases hw₀ : w = 0 + · simp [hw₀] + · simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and, + ite_eq_right_iff] + intros h + simp [show n = 0 by omega] + +@[simp] theorem sshiftRight_zero {x : BitVec w} : x.sshiftRight 0 = x := by + ext i + simp + +theorem sshiftRight_add {x : BitVec w} {m n : Nat} : + x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by + ext i + simp only [getLsb_sshiftRight, Nat.add_assoc] + by_cases h₁ : w ≤ (i : Nat) + · simp [h₁] + · simp only [h₁, decide_False, Bool.not_false, Bool.true_and] + by_cases h₂ : n + ↑i < w + · simp [h₂] + · simp only [h₂, ↓reduceIte] + by_cases h₃ : m + (n + ↑i) < w + · simp [h₃] + omega + · simp [h₃, sshiftRight_msb_eq_msb] + +/-! ### sshiftRight reductions from BitVec to Nat -/ + +@[simp] +theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl + /-! ### signExtend -/ /-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/