From f10d0d07d9ac3ad0fc63d587108f056d565d2b91 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Wed, 29 Jan 2025 10:33:45 +1100 Subject: [PATCH] feat: lemmas about BitVec.setWidth (#6808) This PR adds simp lemmas replacing `BitVec.setWidth'` with `setWidth`, and conditionally simplifying `setWidth v (setWidth w v)`. --------- Co-authored-by: Tobias Grosser --- src/Init/Data/BitVec/Basic.lean | 3 +- src/Init/Data/BitVec/Lemmas.lean | 50 +++++++++++++++++++++++--------- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index ef48cc1d7c..f9fe3bf3b9 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -379,7 +379,8 @@ SMT-Lib name: `extract`. def extractLsb (hi lo : Nat) (x : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ x /-- -A version of `setWidth` that requires a proof, but is a noop. +A version of `setWidth` that requires a proof the new width is at least as large, +and is a computational noop. -/ def setWidth' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w := x.toNat#'(by diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 7a779cbcf3..b9d8410fff 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -605,12 +605,6 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} : (x.setWidth v).toFin = Fin.ofNat' (2^v) x.toNat := by ext; simp -theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by - apply eq_of_toNat_eq - rw [toNat_setWidth, toNat_setWidth'] - rw [Nat.mod_eq_of_lt] - exact Nat.lt_of_lt_of_le x.isLt (Nat.pow_le_pow_right (Nat.zero_lt_two) h) - @[simp] theorem setWidth_eq (x : BitVec n) : setWidth n x = x := by apply eq_of_toNat_eq let ⟨x, lt_n⟩ := x @@ -665,10 +659,10 @@ theorem getElem?_setWidth (m : Nat) (x : BitVec n) (i : Nat) : simp [getLsbD, toNat_setWidth'] @[simp] theorem getMsbD_setWidth' (ge : m ≥ n) (x : BitVec n) (i : Nat) : - getMsbD (setWidth' ge x) i = (decide (i ≥ m - n) && getMsbD x (i - (m - n))) := by + getMsbD (setWidth' ge x) i = (decide (m - n ≤ i) && getMsbD x (i + n - m)) := by simp only [getMsbD, getLsbD_setWidth', gt_iff_lt] - by_cases h₁ : decide (i < m) <;> by_cases h₂ : decide (i ≥ m - n) <;> by_cases h₃ : decide (i - (m - n) < n) <;> - by_cases h₄ : n - 1 - (i - (m - n)) = m - 1 - i + by_cases h₁ : decide (i < m) <;> by_cases h₂ : decide (m - n ≤ i) <;> by_cases h₃ : decide (i + n - m < n) <;> + by_cases h₄ : n - 1 - (i + n - m) = m - 1 - i all_goals simp only [h₁, h₂, h₃, h₄] simp_all only [ge_iff_le, decide_eq_true_eq, Nat.not_le, Nat.not_lt, Bool.true_and, @@ -681,7 +675,7 @@ theorem getElem?_setWidth (m : Nat) (x : BitVec n) (i : Nat) : getLsbD (setWidth m x) i = (decide (i < m) && getLsbD x i) := by simp [getLsbD, toNat_setWidth, Nat.testBit_mod_two_pow] -theorem getMsbD_setWidth {m : Nat} {x : BitVec n} {i : Nat} : +@[simp] theorem getMsbD_setWidth {m : Nat} {x : BitVec n} {i : Nat} : getMsbD (setWidth m x) i = (decide (m - n ≤ i) && getMsbD x (i + n - m)) := by unfold setWidth by_cases h : n ≤ m <;> simp only [h] @@ -695,6 +689,15 @@ theorem getMsbD_setWidth {m : Nat} {x : BitVec n} {i : Nat} : · simp [h'] omega +-- This is a simp lemma as there is only a runtime difference between `setWidth'` and `setWidth`, +-- and for verification purposes they are equivalent. +@[simp] +theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by + apply eq_of_toNat_eq + rw [toNat_setWidth, toNat_setWidth'] + rw [Nat.mod_eq_of_lt] + exact Nat.lt_of_lt_of_le x.isLt (Nat.pow_le_pow_right (Nat.zero_lt_two) h) + @[simp] theorem getMsbD_setWidth_add {x : BitVec w} (h : k ≤ i) : (x.setWidth (w + k)).getMsbD i = x.getMsbD (i - k) := by by_cases h : w = 0 @@ -765,6 +768,22 @@ theorem setWidth_one {x : BitVec w} : rw [Nat.mod_mod_of_dvd] exact Nat.pow_dvd_pow_iff_le_right'.mpr h +/-- +Iterated `setWidth` agrees with the second `setWidth` +except in the case the first `setWidth` is a non-trivial truncation, +and the second `setWidth` is a non-trivial extension. +-/ +-- Note that in the special cases `v = u` or `v = w`, +-- `simp` can discharge the side condition itself. +@[simp] theorem setWidth_setWidth {x : BitVec u} {w v : Nat} (h : ¬ (v < u ∧ v < w)) : + setWidth w (setWidth v x) = setWidth w x := by + ext + simp_all only [getLsbD_setWidth, decide_true, Bool.true_and, Bool.and_iff_right_iff_imp, + decide_eq_true_eq] + intro h + replace h := lt_of_getLsbD h + omega + /-! ## extractLsb -/ @[simp] @@ -1312,7 +1331,7 @@ theorem shiftLeftZeroExtend_eq {x : BitVec w} : apply eq_of_toNat_eq rw [shiftLeftZeroExtend, setWidth] split - · simp + · simp only [toNat_ofNatLt, toNat_shiftLeft, toNat_setWidth'] rw [Nat.mod_eq_of_lt] rw [Nat.shiftLeft_eq, Nat.pow_add] exact Nat.mul_lt_mul_of_pos_right x.isLt (Nat.two_pow_pos _) @@ -1336,11 +1355,15 @@ theorem shiftLeftZeroExtend_eq {x : BitVec w} : @[simp] theorem getMsbD_shiftLeftZeroExtend (x : BitVec m) (n : Nat) : getMsbD (shiftLeftZeroExtend x n) i = getMsbD x i := by + have : m + n - m ≤ i + n := by omega + have : i + n + m - (m + n) = i := by omega simp_all [shiftLeftZeroExtend_eq] @[simp] theorem msb_shiftLeftZeroExtend (x : BitVec w) (i : Nat) : (shiftLeftZeroExtend x i).msb = x.msb := by - simp [shiftLeftZeroExtend_eq, BitVec.msb] + have : w + i - w ≤ i := by omega + have : i + w - (w + i) = 0 := by omega + simp_all [shiftLeftZeroExtend_eq, BitVec.msb] theorem shiftLeft_add {w : Nat} (x : BitVec w) (n m : Nat) : x <<< (n + m) = (x <<< n) <<< m := by @@ -1903,8 +1926,9 @@ theorem getElem_append {x : BitVec n} {y : BitVec m} (h : i < n + m) : @[simp] theorem getMsbD_append {x : BitVec n} {y : BitVec m} : getMsbD (x ++ y) i = if n ≤ i then getMsbD y (i - n) else getMsbD x i := by simp only [append_def] + have : i + m - (n + m) = i - n := by omega by_cases h : n ≤ i - · simp [h] + · simp_all · simp [h] theorem msb_append {x : BitVec w} {y : BitVec v} :