From 4969ec9cdb9a4a6d148b46df3850cdabae3f70ae Mon Sep 17 00:00:00 2001 From: Mac Malone Date: Thu, 28 Nov 2024 21:08:52 -0500 Subject: [PATCH] feat: more UInt lemmas (#6205) This PR upstreams some UInt theorems from Batteries and adds more `toNat`-related theorems. It also adds the missing `UInt8` and `UInt16` to/from `USize` conversions so that the the interface is uniform across the UInt types. **Summary of all changes:** * Upstreamed and added `toNat` constructors lemmas: `toNat_mk`, `ofNat_toNat`, `toNat_ofNat`, `toNat_ofNatCore`, and `USize.toNat_ofNat32` * Upstreamed and added `toNat` canonicalization; `val_val_eq_toNat` and `toNat_toBitVec_eq_toNat` * Added injectivity iffs: `toBitVec_inj`, `toNat_inj`, and `val_inj` * Added inequality iffs: `le_iff_toNat_le` and `lt_iff_toNat_lt` * Upstreamed antisymmetry lemmas: `le_antisymm` and `le_antisymm_iff` * Upstreamed missing `toNat` lemmas on arithmetic operations: `toNat_add`, `toNat_sub`, `toNat_mul` * Upstreamed and added missing conversion lemmas: `toNat_toUInt*` and `toNat_USize` * Added missing `USize` conversions: `USize.toUInt8`, `UInt8.toUSize`, `USize.toUInt16`, `UInt16.toUSize` --- src/Init/Data/UInt/Basic.lean | 10 ++ src/Init/Data/UInt/Lemmas.lean | 249 +++++++++++++++++++++------------ src/include/lean/lean.h | 6 +- 3 files changed, 176 insertions(+), 89 deletions(-) diff --git a/src/Init/Data/UInt/Basic.lean b/src/Init/Data/UInt/Basic.lean index 2009ad0c63..8a41d8df34 100644 --- a/src/Init/Data/UInt/Basic.lean +++ b/src/Init/Data/UInt/Basic.lean @@ -278,6 +278,16 @@ This function is overridden with a native implementation. @[extern "lean_usize_of_nat"] def USize.ofNat32 (n : @& Nat) (h : n < 4294967296) : USize := USize.ofNatCore n (Nat.lt_of_lt_of_le h le_usize_size) +@[extern "lean_uint8_to_usize"] +def UInt8.toUSize (a : UInt8) : USize := + USize.ofNat32 a.toBitVec.toNat (Nat.lt_trans a.toBitVec.isLt (by decide)) +@[extern "lean_usize_to_uint8"] +def USize.toUInt8 (a : USize) : UInt8 := a.toNat.toUInt8 +@[extern "lean_uint16_to_usize"] +def UInt16.toUSize (a : UInt16) : USize := + USize.ofNat32 a.toBitVec.toNat (Nat.lt_trans a.toBitVec.isLt (by decide)) +@[extern "lean_usize_to_uint16"] +def USize.toUInt16 (a : USize) : UInt16 := a.toNat.toUInt16 @[extern "lean_uint32_to_usize"] def UInt32.toUSize (a : UInt32) : USize := USize.ofNat32 a.toBitVec.toNat a.toBitVec.isLt @[extern "lean_usize_to_uint32"] diff --git a/src/Init/Data/UInt/Lemmas.lean b/src/Init/Data/UInt/Lemmas.lean index 7601239b03..b94dd0c2c9 100644 --- a/src/Init/Data/UInt/Lemmas.lean +++ b/src/Init/Data/UInt/Lemmas.lean @@ -1,7 +1,7 @@ /- Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura +Authors: Leonardo de Moura, François G. Dorais, Mario Carneiro, Mac Malone -/ prelude import Init.Data.UInt.Basic @@ -9,129 +9,202 @@ import Init.Data.Fin.Lemmas import Init.Data.BitVec.Lemmas import Init.Data.BitVec.Bitblast +open Lean in set_option hygiene false in -macro "declare_uint_theorems" typeName:ident : command => -`( -namespace $typeName +macro "declare_uint_theorems" typeName:ident bits:term:arg : command => do + let mut cmds ← Syntax.getArgs <$> `( + namespace $typeName -instance : Inhabited $typeName where - default := 0 + theorem zero_def : (0 : $typeName) = ⟨0⟩ := rfl + theorem one_def : (1 : $typeName) = ⟨1⟩ := rfl + theorem sub_def (a b : $typeName) : a - b = ⟨a.toBitVec - b.toBitVec⟩ := rfl + theorem mul_def (a b : $typeName) : a * b = ⟨a.toBitVec * b.toBitVec⟩ := rfl + theorem mod_def (a b : $typeName) : a % b = ⟨a.toBitVec % b.toBitVec⟩ := rfl + theorem add_def (a b : $typeName) : a + b = ⟨a.toBitVec + b.toBitVec⟩ := rfl -theorem zero_def : (0 : $typeName) = ⟨0⟩ := rfl -theorem one_def : (1 : $typeName) = ⟨1⟩ := rfl -theorem sub_def (a b : $typeName) : a - b = ⟨a.toBitVec - b.toBitVec⟩ := rfl -theorem mul_def (a b : $typeName) : a * b = ⟨a.toBitVec * b.toBitVec⟩ := rfl -theorem mod_def (a b : $typeName) : a % b = ⟨a.toBitVec % b.toBitVec⟩ := rfl -theorem add_def (a b : $typeName) : a + b = ⟨a.toBitVec + b.toBitVec⟩ := rfl + @[simp] theorem toNat_mk : (mk a).toNat = a.toNat := rfl -@[simp] theorem mk_toBitVec_eq : ∀ (a : $typeName), mk a.toBitVec = a - | ⟨_, _⟩ => rfl + @[simp] theorem toNat_ofNat {n : Nat} : (ofNat n).toNat = n % 2 ^ $bits := BitVec.toNat_ofNat .. -theorem toBitVec_eq_of_lt {a : Nat} : a < size → (ofNat a).toBitVec.toNat = a := - Nat.mod_eq_of_lt + @[simp] theorem toNat_ofNatCore {n : Nat} {h : n < size} : (ofNatCore n h).toNat = n := BitVec.toNat_ofNatLt .. -theorem toNat_ofNat_of_lt {n : Nat} (h : n < size) : (ofNat n).toNat = n := by - rw [toNat, toBitVec_eq_of_lt h] + @[simp] theorem val_val_eq_toNat (x : $typeName) : x.val.val = x.toNat := rfl -theorem le_def {a b : $typeName} : a ≤ b ↔ a.toBitVec ≤ b.toBitVec := .rfl + theorem toNat_toBitVec_eq_toNat (x : $typeName) : x.toBitVec.toNat = x.toNat := rfl -theorem lt_def {a b : $typeName} : a < b ↔ a.toBitVec < b.toBitVec := .rfl + @[simp] theorem mk_toBitVec_eq : ∀ (a : $typeName), mk a.toBitVec = a + | ⟨_, _⟩ => rfl -@[simp] protected theorem not_le {a b : $typeName} : ¬ a ≤ b ↔ b < a := by simp [le_def, lt_def] + theorem toBitVec_eq_of_lt {a : Nat} : a < size → (ofNat a).toBitVec.toNat = a := + Nat.mod_eq_of_lt -@[simp] protected theorem not_lt {a b : $typeName} : ¬ a < b ↔ b ≤ a := by simp [le_def, lt_def] + theorem toNat_ofNat_of_lt {n : Nat} (h : n < size) : (ofNat n).toNat = n := by + rw [toNat, toBitVec_eq_of_lt h] -@[simp] protected theorem le_refl (a : $typeName) : a ≤ a := by simp [le_def] + theorem le_def {a b : $typeName} : a ≤ b ↔ a.toBitVec ≤ b.toBitVec := .rfl -@[simp] protected theorem lt_irrefl (a : $typeName) : ¬ a < a := by simp + theorem lt_def {a b : $typeName} : a < b ↔ a.toBitVec < b.toBitVec := .rfl -protected theorem le_trans {a b c : $typeName} : a ≤ b → b ≤ c → a ≤ c := BitVec.le_trans + theorem le_iff_toNat_le {a b : $typeName} : a ≤ b ↔ a.toNat ≤ b.toNat := .rfl -protected theorem lt_trans {a b c : $typeName} : a < b → b < c → a < c := BitVec.lt_trans + theorem lt_iff_toNat_lt {a b : $typeName} : a < b ↔ a.toNat < b.toNat := .rfl -protected theorem le_total (a b : $typeName) : a ≤ b ∨ b ≤ a := BitVec.le_total .. + @[simp] protected theorem not_le {a b : $typeName} : ¬ a ≤ b ↔ b < a := by simp [le_def, lt_def] -protected theorem lt_asymm {a b : $typeName} : a < b → ¬ b < a := BitVec.lt_asymm + @[simp] protected theorem not_lt {a b : $typeName} : ¬ a < b ↔ b ≤ a := by simp [le_def, lt_def] -protected theorem toBitVec_eq_of_eq {a b : $typeName} (h : a = b) : a.toBitVec = b.toBitVec := h ▸ rfl + @[simp] protected theorem le_refl (a : $typeName) : a ≤ a := by simp [le_def] -protected theorem eq_of_toBitVec_eq {a b : $typeName} (h : a.toBitVec = b.toBitVec) : a = b := by - cases a; cases b; simp_all + @[simp] protected theorem lt_irrefl (a : $typeName) : ¬ a < a := by simp -open $typeName (eq_of_toBitVec_eq) in -protected theorem eq_of_val_eq {a b : $typeName} (h : a.val = b.val) : a = b := by - rcases a with ⟨⟨_⟩⟩; rcases b with ⟨⟨_⟩⟩; simp_all [val] + protected theorem le_trans {a b c : $typeName} : a ≤ b → b ≤ c → a ≤ c := BitVec.le_trans -open $typeName (toBitVec_eq_of_eq) in -protected theorem ne_of_toBitVec_ne {a b : $typeName} (h : a.toBitVec ≠ b.toBitVec) : a ≠ b := - fun h' => absurd (toBitVec_eq_of_eq h') h + protected theorem lt_trans {a b c : $typeName} : a < b → b < c → a < c := BitVec.lt_trans -open $typeName (ne_of_toBitVec_ne) in -protected theorem ne_of_lt {a b : $typeName} (h : a < b) : a ≠ b := by - apply ne_of_toBitVec_ne - apply BitVec.ne_of_lt - simpa [lt_def] using h + protected theorem le_total (a b : $typeName) : a ≤ b ∨ b ≤ a := BitVec.le_total .. -@[simp] protected theorem toNat_zero : (0 : $typeName).toNat = 0 := Nat.zero_mod _ + protected theorem lt_asymm {a b : $typeName} : a < b → ¬ b < a := BitVec.lt_asymm -@[simp] protected theorem toNat_mod (a b : $typeName) : (a % b).toNat = a.toNat % b.toNat := BitVec.toNat_umod .. + protected theorem toBitVec_eq_of_eq {a b : $typeName} (h : a = b) : a.toBitVec = b.toBitVec := h ▸ rfl -@[simp] protected theorem toNat_div (a b : $typeName) : (a / b).toNat = a.toNat / b.toNat := BitVec.toNat_udiv .. + protected theorem eq_of_toBitVec_eq {a b : $typeName} (h : a.toBitVec = b.toBitVec) : a = b := by + cases a; cases b; simp_all -@[simp] protected theorem toNat_sub_of_le (a b : $typeName) : b ≤ a → (a - b).toNat = a.toNat - b.toNat := BitVec.toNat_sub_of_le + open $typeName (eq_of_toBitVec_eq toBitVec_eq_of_eq) in + protected theorem toBitVec_inj {a b : $typeName} : a.toBitVec = b.toBitVec ↔ a = b := + Iff.intro eq_of_toBitVec_eq toBitVec_eq_of_eq -protected theorem toNat_lt_size (a : $typeName) : a.toNat < size := a.toBitVec.isLt + open $typeName (eq_of_toBitVec_eq) in + protected theorem eq_of_val_eq {a b : $typeName} (h : a.val = b.val) : a = b := by + rcases a with ⟨⟨_⟩⟩; rcases b with ⟨⟨_⟩⟩; simp_all [val] -open $typeName (toNat_mod toNat_lt_size) in -protected theorem toNat_mod_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % ofNat m) < m := by - intro u h1 - by_cases h2 : m < size - · rw [toNat_mod, toNat_ofNat_of_lt h2] - apply Nat.mod_lt _ h1 - · apply Nat.lt_of_lt_of_le - · apply toNat_lt_size - · simpa using h2 + open $typeName (eq_of_val_eq) in + protected theorem val_inj {a b : $typeName} : a.val = b.val ↔ a = b := + Iff.intro eq_of_val_eq (congrArg val) -open $typeName (toNat_mod_lt) in -set_option linter.deprecated false in -@[deprecated toNat_mod_lt (since := "2024-09-24")] -protected theorem modn_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % m) < m := by - intro u - simp only [(· % ·)] - simp only [gt_iff_lt, toNat, modn, Fin.modn_val, BitVec.natCast_eq_ofNat, BitVec.toNat_ofNat, - Nat.reducePow] - rw [Nat.mod_eq_of_lt] - · apply Nat.mod_lt - · apply Nat.lt_of_le_of_lt - · apply Nat.mod_le - · apply Fin.is_lt + open $typeName (toBitVec_eq_of_eq) in + protected theorem ne_of_toBitVec_ne {a b : $typeName} (h : a.toBitVec ≠ b.toBitVec) : a ≠ b := + fun h' => absurd (toBitVec_eq_of_eq h') h -protected theorem mod_lt (a : $typeName) {b : $typeName} : 0 < b → a % b < b := by - simp only [lt_def, mod_def] - apply BitVec.umod_lt + open $typeName (ne_of_toBitVec_ne) in + protected theorem ne_of_lt {a b : $typeName} (h : a < b) : a ≠ b := by + apply ne_of_toBitVec_ne + apply BitVec.ne_of_lt + simpa [lt_def] using h -protected theorem toNat.inj : ∀ {a b : $typeName}, a.toNat = b.toNat → a = b - | ⟨_, _⟩, ⟨_, _⟩, rfl => rfl + @[simp] protected theorem toNat_zero : (0 : $typeName).toNat = 0 := Nat.zero_mod _ -@[simp] protected theorem ofNat_one : ofNat 1 = 1 := rfl + @[simp] protected theorem toNat_add (a b : $typeName) : (a + b).toNat = (a.toNat + b.toNat) % 2 ^ $bits := BitVec.toNat_add .. -@[simp] -theorem val_ofNat (n : Nat) : val (no_index (OfNat.ofNat n)) = OfNat.ofNat n := rfl + protected theorem toNat_sub (a b : $typeName) : (a - b).toNat = (2 ^ $bits - b.toNat + a.toNat) % 2 ^ $bits := BitVec.toNat_sub .. -@[simp] -theorem toBitVec_ofNat (n : Nat) : toBitVec (no_index (OfNat.ofNat n)) = BitVec.ofNat _ n := rfl + @[simp] protected theorem toNat_mul (a b : $typeName) : (a * b).toNat = a.toNat * b.toNat % 2 ^ $bits := BitVec.toNat_mul .. -@[simp] -theorem mk_ofNat (n : Nat) : mk (BitVec.ofNat _ n) = OfNat.ofNat n := rfl + @[simp] protected theorem toNat_mod (a b : $typeName) : (a % b).toNat = a.toNat % b.toNat := BitVec.toNat_umod .. -end $typeName -) + @[simp] protected theorem toNat_div (a b : $typeName) : (a / b).toNat = a.toNat / b.toNat := BitVec.toNat_udiv .. -declare_uint_theorems UInt8 -declare_uint_theorems UInt16 -declare_uint_theorems UInt32 -declare_uint_theorems UInt64 -declare_uint_theorems USize + @[simp] protected theorem toNat_sub_of_le (a b : $typeName) : b ≤ a → (a - b).toNat = a.toNat - b.toNat := BitVec.toNat_sub_of_le + + protected theorem toNat_lt_size (a : $typeName) : a.toNat < size := a.toBitVec.isLt + + open $typeName (toNat_mod toNat_lt_size) in + protected theorem toNat_mod_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % ofNat m) < m := by + intro u h1 + by_cases h2 : m < size + · rw [toNat_mod, toNat_ofNat_of_lt h2] + apply Nat.mod_lt _ h1 + · apply Nat.lt_of_lt_of_le + · apply toNat_lt_size + · simpa using h2 + + open $typeName (toNat_mod_lt) in + set_option linter.deprecated false in + @[deprecated toNat_mod_lt (since := "2024-09-24")] + protected theorem modn_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % m) < m := by + intro u + simp only [(· % ·)] + simp only [gt_iff_lt, toNat, modn, Fin.modn_val, BitVec.natCast_eq_ofNat, BitVec.toNat_ofNat, + Nat.reducePow] + rw [Nat.mod_eq_of_lt] + · apply Nat.mod_lt + · apply Nat.lt_of_le_of_lt + · apply Nat.mod_le + · apply Fin.is_lt + + protected theorem mod_lt (a : $typeName) {b : $typeName} : 0 < b → a % b < b := by + simp only [lt_def, mod_def] + apply BitVec.umod_lt + + protected theorem toNat.inj : ∀ {a b : $typeName}, a.toNat = b.toNat → a = b + | ⟨_, _⟩, ⟨_, _⟩, rfl => rfl + + protected theorem toNat_inj : ∀ {a b : $typeName}, a.toNat = b.toNat ↔ a = b := + Iff.intro toNat.inj (congrArg toNat) + + open $typeName (toNat_inj) in + protected theorem le_antisymm_iff {a b : $typeName} : a = b ↔ a ≤ b ∧ b ≤ a := + toNat_inj.symm.trans Nat.le_antisymm_iff + + open $typeName (le_antisymm_iff) in + protected theorem le_antisymm {a b : $typeName} (h₁ : a ≤ b) (h₂ : b ≤ a) : a = b := + le_antisymm_iff.2 ⟨h₁, h₂⟩ + + @[simp] protected theorem ofNat_one : ofNat 1 = 1 := rfl + + @[simp] protected theorem ofNat_toNat {x : $typeName} : ofNat x.toNat = x := by + apply toNat.inj + simp [Nat.mod_eq_of_lt x.toNat_lt_size] + + @[simp] + theorem val_ofNat (n : Nat) : val (no_index (OfNat.ofNat n)) = OfNat.ofNat n := rfl + + @[simp] + theorem toBitVec_ofNat (n : Nat) : toBitVec (no_index (OfNat.ofNat n)) = BitVec.ofNat _ n := rfl + + @[simp] + theorem mk_ofNat (n : Nat) : mk (BitVec.ofNat _ n) = OfNat.ofNat n := rfl + + ) + if let some nbits := bits.raw.isNatLit? then + if nbits > 8 then + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUInt8 (x : $typeName) : x.toUInt8.toNat = x.toNat % 2 ^ 8 := rfl) + if nbits < 16 then + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUInt16 (x : $typeName) : x.toUInt16.toNat = x.toNat := rfl) + else if nbits > 16 then + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUInt16 (x : $typeName) : x.toUInt16.toNat = x.toNat % 2 ^ 16 := rfl) + if nbits < 32 then + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUInt32 (x : $typeName) : x.toUInt32.toNat = x.toNat := rfl) + else if nbits > 32 then + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUInt32 (x : $typeName) : x.toUInt32.toNat = x.toNat % 2 ^ 32 := rfl) + if nbits ≤ 32 then + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUSize (x : $typeName) : x.toUSize.toNat = x.toNat := rfl) + else + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUSize (x : $typeName) : x.toUSize.toNat = x.toNat % 2 ^ System.Platform.numBits := rfl) + if nbits < 64 then + cmds := cmds.push <| ← + `(@[simp] theorem toNat_toUInt64 (x : $typeName) : x.toUInt64.toNat = x.toNat := rfl) + cmds := cmds.push <| ← `(end $typeName) + return ⟨mkNullNode cmds⟩ + +declare_uint_theorems UInt8 8 +declare_uint_theorems UInt16 16 +declare_uint_theorems UInt32 32 +declare_uint_theorems UInt64 64 +declare_uint_theorems USize System.Platform.numBits + +@[simp] theorem USize.toNat_ofNat32 {n : Nat} {h : n < 4294967296} : (ofNat32 n h).toNat = n := rfl + +@[simp] theorem USize.toNat_toUInt32 (x : USize) : x.toUInt32.toNat = x.toNat % 2 ^ 32 := rfl + +@[simp] theorem USize.toNat_toUInt64 (x : USize) : x.toUInt64.toNat = x.toNat := rfl theorem USize.toNat_ofNat_of_lt_32 {n : Nat} (h : n < 4294967296) : toNat (ofNat n) = n := toNat_ofNat_of_lt (Nat.lt_of_lt_of_le h le_usize_size) diff --git a/src/include/lean/lean.h b/src/include/lean/lean.h index c0b127058f..5356f719be 100644 --- a/src/include/lean/lean.h +++ b/src/include/lean/lean.h @@ -1692,6 +1692,7 @@ static inline uint8_t lean_uint8_dec_le(uint8_t a1, uint8_t a2) { return a1 <= a static inline uint16_t lean_uint8_to_uint16(uint8_t a) { return ((uint16_t)a); } static inline uint32_t lean_uint8_to_uint32(uint8_t a) { return ((uint32_t)a); } static inline uint64_t lean_uint8_to_uint64(uint8_t a) { return ((uint64_t)a); } +static inline size_t lean_uint8_to_usize(uint8_t a) { return ((size_t)a); } /* UInt16 */ @@ -1727,6 +1728,7 @@ static inline uint8_t lean_uint16_dec_le(uint16_t a1, uint16_t a2) { return a1 < static inline uint8_t lean_uint16_to_uint8(uint16_t a) { return ((uint8_t)a); } static inline uint32_t lean_uint16_to_uint32(uint16_t a) { return ((uint32_t)a); } static inline uint64_t lean_uint16_to_uint64(uint16_t a) { return ((uint64_t)a); } +static inline size_t lean_uint16_to_usize(uint16_t a) { return ((size_t)a); } /* UInt32 */ @@ -1762,7 +1764,7 @@ static inline uint8_t lean_uint32_dec_le(uint32_t a1, uint32_t a2) { return a1 < static inline uint8_t lean_uint32_to_uint8(uint32_t a) { return ((uint8_t)a); } static inline uint16_t lean_uint32_to_uint16(uint32_t a) { return ((uint16_t)a); } static inline uint64_t lean_uint32_to_uint64(uint32_t a) { return ((uint64_t)a); } -static inline size_t lean_uint32_to_usize(uint32_t a) { return a; } +static inline size_t lean_uint32_to_usize(uint32_t a) { return ((size_t)a); } /* UInt64 */ @@ -1834,6 +1836,8 @@ static inline uint8_t lean_usize_dec_le(size_t a1, size_t a2) { return a1 <= a2; /* usize -> other */ +static inline uint8_t lean_usize_to_uint8(size_t a) { return ((uint8_t)a); } +static inline uint16_t lean_usize_to_uint16(size_t a) { return ((uint16_t)a); } static inline uint32_t lean_usize_to_uint32(size_t a) { return ((uint32_t)a); } static inline uint64_t lean_usize_to_uint64(size_t a) { return ((uint64_t)a); }