From 0d30517dca094a07bcb462252f718e713b93ffba Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 24 May 2024 06:34:56 -0700 Subject: [PATCH] feat: make `#` bitvector literal notation global chore: `toFin_ofNat` --- src/Init/Data/BitVec/Basic.lean | 8 ++++---- src/Init/Data/BitVec/Lemmas.lean | 32 ++++++++++++++++---------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 1d12641f0a..431df33bd0 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -151,12 +151,12 @@ end Int section Syntax /-- Notation for bit vector literals. `i#n` is a shorthand for `BitVec.ofNat n i`. -/ -scoped syntax:max term:max noWs "#" noWs term:max : term -macro_rules | `($i#$n) => `(BitVec.ofNat $n $i) +syntax:max num noWs "#" noWs term:max : term +macro_rules | `($i:num#$n) => `(BitVec.ofNat $n $i) /-- Unexpander for bit vector literals. -/ @[app_unexpander BitVec.ofNat] def unexpandBitVecOfNat : Lean.PrettyPrinter.Unexpander - | `($(_) $n $i) => `($i#$n) + | `($(_) $n $i:num) => `($i:num#$n) | _ => throw () /-- Notation for bit vector literals without truncation. `i#'lt` is a shorthand for `BitVec.ofNatLt i lt`. -/ @@ -504,7 +504,7 @@ equivalent to `a * 2^s`, modulo `2^n`. SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value. -/ -protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := (a.toNat <<< s)#n +protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (a.toNat <<< s) instance : HShiftLeft (BitVec w) Nat (BitVec w) := ⟨.shiftLeft⟩ /-- diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index bdd0f5cb0d..a1a27bafda 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -139,15 +139,15 @@ theorem ofBool_eq_iff_eq : ∀(b b' : Bool), BitVec.ofBool b = BitVec.ofBool b' getLsb (x#'lt) i = x.testBit i := by simp [getLsb, BitVec.ofNatLt] -@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (x#w).toNat = x % 2^w := by +@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat'] -@[simp] theorem toFin_ofNat (x : Nat) : toFin x#w = Fin.ofNat' x (Nat.two_pow_pos w) := rfl +@[simp] theorem toFin_ofNat (x : Nat) : toFin (BitVec.ofNat w x) = Fin.ofNat' x (Nat.two_pow_pos w) := rfl -- Remark: we don't use `[simp]` here because simproc` subsumes it for literals. -- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea. theorem getLsb_ofNat (n : Nat) (x : Nat) (i : Nat) : - getLsb (x#n) i = (i < n && x.testBit i) := by + getLsb (BitVec.ofNat n x) i = (i < n && x.testBit i) := by simp [getLsb, BitVec.ofNat, Fin.val_ofNat'] @[simp, deprecated toNat_ofNat (since := "2024-02-22")] @@ -316,19 +316,19 @@ theorem zeroExtend'_eq {x : BitVec w} (h : w ≤ v) : x.zeroExtend' h = x.zeroEx let ⟨x, lt_n⟩ := x simp [truncate, zeroExtend] -@[simp] theorem zeroExtend_zero (m n : Nat) : zeroExtend m (0#n) = 0#m := by +@[simp] theorem zeroExtend_zero (m n : Nat) : zeroExtend m 0#n = 0#m := by apply eq_of_toNat_eq simp [toNat_zeroExtend] @[simp] theorem truncate_eq (x : BitVec n) : truncate n x = x := zeroExtend_eq x -@[simp] theorem ofNat_toNat (m : Nat) (x : BitVec n) : x.toNat#m = truncate m x := by +@[simp] theorem ofNat_toNat (m : Nat) (x : BitVec n) : BitVec.ofNat m x.toNat = truncate m x := by apply eq_of_toNat_eq simp /-- Moves one-sided left toNat equality to BitVec equality. -/ theorem toNat_eq_nat (x : BitVec w) (y : Nat) - : (x.toNat = y) ↔ (y < 2^w ∧ (x = y#w)) := by + : (x.toNat = y) ↔ (y < 2^w ∧ (x = BitVec.ofNat w y)) := by apply Iff.intro · intro eq simp at eq @@ -340,7 +340,7 @@ theorem toNat_eq_nat (x : BitVec w) (y : Nat) /-- Moves one-sided right toNat equality to BitVec equality. -/ theorem nat_eq_toNat (x : BitVec w) (y : Nat) - : (y = x.toNat) ↔ (y < 2^w ∧ (x = y#w)) := by + : (y = x.toNat) ↔ (y < 2^w ∧ (x = BitVec.ofNat w y)) := by rw [@eq_comm _ _ x.toNat] apply toNat_eq_nat @@ -416,7 +416,7 @@ protected theorem extractLsb_ofFin {n} (x : Fin (2^n)) (hi lo : Nat) : @[simp] protected theorem extractLsb_ofNat (x n : Nat) (hi lo : Nat) : - extractLsb hi lo x#n = .ofNat (hi - lo + 1) ((x % 2^n) >>> lo) := by + extractLsb hi lo (BitVec.ofNat n x) = .ofNat (hi - lo + 1) ((x % 2^n) >>> lo) := by apply eq_of_getLsb_eq intro ⟨i, _lt⟩ simp [BitVec.ofNat] @@ -1008,10 +1008,10 @@ Definition of bitvector addition as a nat. @[simp] theorem add_ofFin (x : BitVec n) (y : Fin (2^n)) : x + .ofFin y = .ofFin (x.toFin + y) := rfl -theorem ofNat_add {n} (x y : Nat) : (x + y)#n = x#n + y#n := by +theorem ofNat_add {n} (x y : Nat) : BitVec.ofNat n (x + y) = BitVec.ofNat n x + BitVec.ofNat n y := by apply eq_of_toNat_eq ; simp [BitVec.ofNat] -theorem ofNat_add_ofNat {n} (x y : Nat) : x#n + y#n = (x + y)#n := +theorem ofNat_add_ofNat {n} (x y : Nat) : BitVec.ofNat n x + BitVec.ofNat n y = BitVec.ofNat n (x + y) := (ofNat_add x y).symm protected theorem add_assoc (x y z : BitVec n) : x + y + z = x + (y + z) := by @@ -1057,10 +1057,10 @@ theorem sub_def {n} (x y : BitVec n) : x - y = .ofNat n (x.toNat + (2^n - y.toNa rfl -- Remark: we don't use `[simp]` here because simproc` subsumes it for literals. -- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea. -theorem ofNat_sub_ofNat {n} (x y : Nat) : x#n - y#n = .ofNat n (x + (2^n - y % 2^n)) := by +theorem ofNat_sub_ofNat {n} (x y : Nat) : BitVec.ofNat n x - BitVec.ofNat n y = .ofNat n (x + (2^n - y % 2^n)) := by apply eq_of_toNat_eq ; simp [BitVec.ofNat] -@[simp] protected theorem sub_zero (x : BitVec n) : x - (0#n) = x := by apply eq_of_toNat_eq ; simp +@[simp] protected theorem sub_zero (x : BitVec n) : x - 0#n = x := by apply eq_of_toNat_eq ; simp @[simp] protected theorem sub_self (x : BitVec n) : x - x = 0#n := by apply eq_of_toNat_eq @@ -1080,7 +1080,7 @@ theorem sub_toAdd {n} (x y : BitVec n) : x - y = x + - y := by apply eq_of_toNat_eq simp -@[simp] theorem neg_zero (n:Nat) : -0#n = 0#n := by apply eq_of_toNat_eq ; simp +@[simp] theorem neg_zero (n:Nat) : -BitVec.ofNat n 0 = BitVec.ofNat n 0 := by apply eq_of_toNat_eq ; simp theorem add_sub_cancel (x y : BitVec w) : x + y - y = x := by apply eq_of_toNat_eq @@ -1157,7 +1157,7 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) = x ≤ BitVec.ofFin y ↔ x.toFin ≤ y := Iff.rfl @[simp] theorem ofFin_le (x : Fin (2^n)) (y : BitVec n) : BitVec.ofFin x ≤ y ↔ x ≤ y.toFin := Iff.rfl -@[simp] theorem ofNat_le_ofNat {n} (x y : Nat) : (x#n) ≤ (y#n) ↔ x % 2^n ≤ y % 2^n := by +@[simp] theorem ofNat_le_ofNat {n} (x y : Nat) : (BitVec.ofNat n x) ≤ (BitVec.ofNat n y) ↔ x % 2^n ≤ y % 2^n := by simp [le_def] @[bv_toNat] theorem lt_def (x y : BitVec n) : @@ -1167,7 +1167,7 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) = x < BitVec.ofFin y ↔ x.toFin < y := Iff.rfl @[simp] theorem ofFin_lt (x : Fin (2^n)) (y : BitVec n) : BitVec.ofFin x < y ↔ x < y.toFin := Iff.rfl -@[simp] theorem ofNat_lt_ofNat {n} (x y : Nat) : (x#n) < (y#n) ↔ x % 2^n < y % 2^n := by +@[simp] theorem ofNat_lt_ofNat {n} (x y : Nat) : BitVec.ofNat n x < BitVec.ofNat n y ↔ x % 2^n < y % 2^n := by simp [lt_def] protected theorem lt_of_le_ne (x y : BitVec n) (h1 : x <= y) (h2 : ¬ x = y) : x < y := by @@ -1180,7 +1180,7 @@ protected theorem lt_of_le_ne (x y : BitVec n) (h1 : x <= y) (h2 : ¬ x = y) : x /-! ### intMax -/ /-- The bitvector of width `w` that has the largest value when interpreted as an integer. -/ -def intMax (w : Nat) : BitVec w := (2^w - 1)#w +def intMax (w : Nat) : BitVec w := BitVec.ofNat w (2^w - 1) theorem getLsb_intMax_eq (w : Nat) : (intMax w).getLsb i = decide (i < w) := by simp [intMax, getLsb]