diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index b9d8410fff..edc71c55b6 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1193,6 +1193,10 @@ theorem not_not {b : BitVec w} : ~~~(~~~b) = b := by ext i h simp [h] +@[simp] +protected theorem not_inj {x y : BitVec w} : ~~~x = ~~~y ↔ x = y := + ⟨fun h => by rw [← @not_not w x, ← @not_not w y, h], congrArg _⟩ + @[simp] theorem and_not_self (x : BitVec n) : x &&& ~~~x = 0 := by ext i simp_all @@ -2515,6 +2519,10 @@ theorem neg_neg {x : BitVec w} : - - x = x := by · simp [h] · simp [bv_toNat, h] +@[simp] +protected theorem neg_inj {x y : BitVec w} : -x = -y ↔ x = y := + ⟨fun h => by rw [← @neg_neg w x, ← @neg_neg w y, h], congrArg _⟩ + theorem neg_ne_iff_ne_neg {x y : BitVec w} : -x ≠ y ↔ x ≠ -y := by constructor all_goals @@ -2557,6 +2565,27 @@ theorem not_neg (x : BitVec w) : ~~~(-x) = x + -1#w := by show (_ - x.toNat) % _ = _ by rw [Nat.mod_eq_of_lt (by omega)]] omega +/- ### add/sub injectivity -/ + +@[simp] +protected theorem add_left_inj {x y : BitVec w} (z : BitVec w) : (x + z = y + z) ↔ x = y := by + apply Iff.intro + · intro p + rw [← add_sub_cancel x z, ← add_sub_cancel y z, p] + · exact congrArg (· + z) + +@[simp] +protected theorem add_right_inj {x y : BitVec w} (z : BitVec w) : (z + x = z + y) ↔ x = y := by + simp [BitVec.add_comm z] + +@[simp] +protected theorem sub_left_inj {x y : BitVec w} (z : BitVec w) : (x - z = y - z) ↔ x = y := by + simp [sub_toAdd] + +@[simp] +protected theorem sub_right_inj {x y : BitVec w} (z : BitVec w) : (z - x = z - y) ↔ x = y := by + simp [sub_toAdd] + /-! ### fill -/ @[simp] diff --git a/src/Std/Tactic/BVDecide/Normalize/Equal.lean b/src/Std/Tactic/BVDecide/Normalize/Equal.lean index da72bcd885..389d8351d2 100644 --- a/src/Std/Tactic/BVDecide/Normalize/Equal.lean +++ b/src/Std/Tactic/BVDecide/Normalize/Equal.lean @@ -24,11 +24,26 @@ theorem Bool.not_beq_not : ∀ (a b : Bool), ((!a) == (!b)) = (a == b) := by @[bv_normalize] theorem BitVec.not_beq_not (a b : BitVec w) : (~~~a == ~~~b) = (a == b) := by - match h : a == b with - | true => simp_all - | false => - simp only [beq_eq_false_iff_ne, ne_eq] at * - bv_omega + rw [Bool.eq_iff_iff] + simp + +@[bv_normalize] +theorem BitVec.add_left_inj (a b c : BitVec w) : (a + c == b + c) = (a == b) := by + rw [Bool.eq_iff_iff] + simp + +@[bv_normalize] +theorem BitVec.add_left_inj' (a b c : BitVec w) : (a + c == c + b) = (a == b) := by + rw [BitVec.add_comm c b, add_left_inj] + +@[bv_normalize] +theorem BitVec.add_right_inj (a b c : BitVec w) : (c + a == c + b) = (a == b) := by + rw [Bool.eq_iff_iff] + simp + +@[bv_normalize] +theorem BitVec.add_right_inj' (a b c : BitVec w) : (c + a == b + c) = (a == b) := by + rw [BitVec.add_comm b c, add_right_inj] end Frontend.Normalize end Std.Tactic.BVDecide diff --git a/tests/lean/run/bv_decide_rewriter.lean b/tests/lean/run/bv_decide_rewriter.lean index b29811ca0f..0b36cbcb1f 100644 --- a/tests/lean/run/bv_decide_rewriter.lean +++ b/tests/lean/run/bv_decide_rewriter.lean @@ -104,6 +104,12 @@ example (x : BitVec 16) : (x.ult 1) = (x == 0) := by bv_normalize -- ushiftRight_self example (x : BitVec 16) : (x >>> x) == 0 := by bv_normalize +-- add_left_inj / add_right_inj +example (x y z : BitVec 16) : (x + z == y + z) = (x == y) := by bv_normalize +example (x y z : BitVec 16) : (x + z == z + y) = (x == y) := by bv_normalize +example (x y z : BitVec 16) : (z + x == y + z) = (x == y) := by bv_normalize +example (x y z : BitVec 16) : (z + x == z + y) = (x == y) := by bv_normalize + section example (x y : BitVec 256) : x * y = y * x := by