feat: more UInt lemmas (#6205)

This PR upstreams some UInt theorems from Batteries and adds more
`toNat`-related theorems. It also adds the missing `UInt8` and `UInt16`
to/from `USize` conversions so that the the interface is uniform across
the UInt types.

**Summary of all changes:**

* Upstreamed and added `toNat` constructors lemmas: `toNat_mk`,
`ofNat_toNat`, `toNat_ofNat`, `toNat_ofNatCore`, and
`USize.toNat_ofNat32`
* Upstreamed and added `toNat` canonicalization; `val_val_eq_toNat` and
`toNat_toBitVec_eq_toNat`
* Added injectivity iffs: `toBitVec_inj`, `toNat_inj`, and `val_inj`
* Added inequality iffs: `le_iff_toNat_le` and `lt_iff_toNat_lt`
* Upstreamed antisymmetry lemmas: `le_antisymm` and `le_antisymm_iff`
* Upstreamed missing `toNat` lemmas on arithmetic operations:
`toNat_add`, `toNat_sub`, `toNat_mul`
* Upstreamed and added missing conversion lemmas: `toNat_toUInt*` and
`toNat_USize`
* Added missing `USize` conversions: `USize.toUInt8`, `UInt8.toUSize`,
`USize.toUInt16`, `UInt16.toUSize`
This commit is contained in:
Mac Malone 2024-11-28 21:08:52 -05:00 committed by GitHub
parent 827062f807
commit 4969ec9cdb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 176 additions and 89 deletions

View file

@ -278,6 +278,16 @@ This function is overridden with a native implementation.
@[extern "lean_usize_of_nat"]
def USize.ofNat32 (n : @& Nat) (h : n < 4294967296) : USize :=
USize.ofNatCore n (Nat.lt_of_lt_of_le h le_usize_size)
@[extern "lean_uint8_to_usize"]
def UInt8.toUSize (a : UInt8) : USize :=
USize.ofNat32 a.toBitVec.toNat (Nat.lt_trans a.toBitVec.isLt (by decide))
@[extern "lean_usize_to_uint8"]
def USize.toUInt8 (a : USize) : UInt8 := a.toNat.toUInt8
@[extern "lean_uint16_to_usize"]
def UInt16.toUSize (a : UInt16) : USize :=
USize.ofNat32 a.toBitVec.toNat (Nat.lt_trans a.toBitVec.isLt (by decide))
@[extern "lean_usize_to_uint16"]
def USize.toUInt16 (a : USize) : UInt16 := a.toNat.toUInt16
@[extern "lean_uint32_to_usize"]
def UInt32.toUSize (a : UInt32) : USize := USize.ofNat32 a.toBitVec.toNat a.toBitVec.isLt
@[extern "lean_usize_to_uint32"]

View file

@ -1,7 +1,7 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
Authors: Leonardo de Moura, François G. Dorais, Mario Carneiro, Mac Malone
-/
prelude
import Init.Data.UInt.Basic
@ -9,129 +9,202 @@ import Init.Data.Fin.Lemmas
import Init.Data.BitVec.Lemmas
import Init.Data.BitVec.Bitblast
open Lean in
set_option hygiene false in
macro "declare_uint_theorems" typeName:ident : command =>
`(
namespace $typeName
macro "declare_uint_theorems" typeName:ident bits:term:arg : command => do
let mut cmds ← Syntax.getArgs <$> `(
namespace $typeName
instance : Inhabited $typeName where
default := 0
theorem zero_def : (0 : $typeName) = ⟨0⟩ := rfl
theorem one_def : (1 : $typeName) = ⟨1⟩ := rfl
theorem sub_def (a b : $typeName) : a - b = ⟨a.toBitVec - b.toBitVec⟩ := rfl
theorem mul_def (a b : $typeName) : a * b = ⟨a.toBitVec * b.toBitVec⟩ := rfl
theorem mod_def (a b : $typeName) : a % b = ⟨a.toBitVec % b.toBitVec⟩ := rfl
theorem add_def (a b : $typeName) : a + b = ⟨a.toBitVec + b.toBitVec⟩ := rfl
theorem zero_def : (0 : $typeName) = ⟨0⟩ := rfl
theorem one_def : (1 : $typeName) = ⟨1⟩ := rfl
theorem sub_def (a b : $typeName) : a - b = ⟨a.toBitVec - b.toBitVec⟩ := rfl
theorem mul_def (a b : $typeName) : a * b = ⟨a.toBitVec * b.toBitVec⟩ := rfl
theorem mod_def (a b : $typeName) : a % b = ⟨a.toBitVec % b.toBitVec⟩ := rfl
theorem add_def (a b : $typeName) : a + b = ⟨a.toBitVec + b.toBitVec⟩ := rfl
@[simp] theorem toNat_mk : (mk a).toNat = a.toNat := rfl
@[simp] theorem mk_toBitVec_eq : ∀ (a : $typeName), mk a.toBitVec = a
| ⟨_, _⟩ => rfl
@[simp] theorem toNat_ofNat {n : Nat} : (ofNat n).toNat = n % 2 ^ $bits := BitVec.toNat_ofNat ..
theorem toBitVec_eq_of_lt {a : Nat} : a < size → (ofNat a).toBitVec.toNat = a :=
Nat.mod_eq_of_lt
@[simp] theorem toNat_ofNatCore {n : Nat} {h : n < size} : (ofNatCore n h).toNat = n := BitVec.toNat_ofNatLt ..
theorem toNat_ofNat_of_lt {n : Nat} (h : n < size) : (ofNat n).toNat = n := by
rw [toNat, toBitVec_eq_of_lt h]
@[simp] theorem val_val_eq_toNat (x : $typeName) : x.val.val = x.toNat := rfl
theorem le_def {a b : $typeName} : a ≤ b ↔ a.toBitVec ≤ b.toBitVec := .rfl
theorem toNat_toBitVec_eq_toNat (x : $typeName) : x.toBitVec.toNat = x.toNat := rfl
theorem lt_def {a b : $typeName} : a < b ↔ a.toBitVec < b.toBitVec := .rfl
@[simp] theorem mk_toBitVec_eq : ∀ (a : $typeName), mk a.toBitVec = a
| ⟨_, _⟩ => rfl
@[simp] protected theorem not_le {a b : $typeName} : ¬ a ≤ b ↔ b < a := by simp [le_def, lt_def]
theorem toBitVec_eq_of_lt {a : Nat} : a < size → (ofNat a).toBitVec.toNat = a :=
Nat.mod_eq_of_lt
@[simp] protected theorem not_lt {a b : $typeName} : ¬ a < b ↔ b ≤ a := by simp [le_def, lt_def]
theorem toNat_ofNat_of_lt {n : Nat} (h : n < size) : (ofNat n).toNat = n := by
rw [toNat, toBitVec_eq_of_lt h]
@[simp] protected theorem le_refl (a : $typeName) : a ≤ a := by simp [le_def]
theorem le_def {a b : $typeName} : a ≤ b ↔ a.toBitVec ≤ b.toBitVec := .rfl
@[simp] protected theorem lt_irrefl (a : $typeName) : ¬ a < a := by simp
theorem lt_def {a b : $typeName} : a < b ↔ a.toBitVec < b.toBitVec := .rfl
protected theorem le_trans {a b c : $typeName} : a ≤ b → b ≤ c → a ≤ c := BitVec.le_trans
theorem le_iff_toNat_le {a b : $typeName} : a ≤ b ↔ a.toNat ≤ b.toNat := .rfl
protected theorem lt_trans {a b c : $typeName} : a < b → b < c → a < c := BitVec.lt_trans
theorem lt_iff_toNat_lt {a b : $typeName} : a < b ↔ a.toNat < b.toNat := .rfl
protected theorem le_total (a b : $typeName) : a ≤ b b ≤ a := BitVec.le_total ..
@[simp] protected theorem not_le {a b : $typeName} : ¬ a ≤ b ↔ b < a := by simp [le_def, lt_def]
protected theorem lt_asymm {a b : $typeName} : a < b → ¬ b < a := BitVec.lt_asymm
@[simp] protected theorem not_lt {a b : $typeName} : ¬ a < b ↔ b ≤ a := by simp [le_def, lt_def]
protected theorem toBitVec_eq_of_eq {a b : $typeName} (h : a = b) : a.toBitVec = b.toBitVec := h ▸ rfl
@[simp] protected theorem le_refl (a : $typeName) : a ≤ a := by simp [le_def]
protected theorem eq_of_toBitVec_eq {a b : $typeName} (h : a.toBitVec = b.toBitVec) : a = b := by
cases a; cases b; simp_all
@[simp] protected theorem lt_irrefl (a : $typeName) : ¬ a < a := by simp
open $typeName (eq_of_toBitVec_eq) in
protected theorem eq_of_val_eq {a b : $typeName} (h : a.val = b.val) : a = b := by
rcases a with ⟨⟨_⟩⟩; rcases b with ⟨⟨_⟩⟩; simp_all [val]
protected theorem le_trans {a b c : $typeName} : a ≤ b → b ≤ c → a ≤ c := BitVec.le_trans
open $typeName (toBitVec_eq_of_eq) in
protected theorem ne_of_toBitVec_ne {a b : $typeName} (h : a.toBitVec ≠ b.toBitVec) : a ≠ b :=
fun h' => absurd (toBitVec_eq_of_eq h') h
protected theorem lt_trans {a b c : $typeName} : a < b → b < c → a < c := BitVec.lt_trans
open $typeName (ne_of_toBitVec_ne) in
protected theorem ne_of_lt {a b : $typeName} (h : a < b) : a ≠ b := by
apply ne_of_toBitVec_ne
apply BitVec.ne_of_lt
simpa [lt_def] using h
protected theorem le_total (a b : $typeName) : a ≤ b b ≤ a := BitVec.le_total ..
@[simp] protected theorem toNat_zero : (0 : $typeName).toNat = 0 := Nat.zero_mod _
protected theorem lt_asymm {a b : $typeName} : a < b → ¬ b < a := BitVec.lt_asymm
@[simp] protected theorem toNat_mod (a b : $typeName) : (a % b).toNat = a.toNat % b.toNat := BitVec.toNat_umod ..
protected theorem toBitVec_eq_of_eq {a b : $typeName} (h : a = b) : a.toBitVec = b.toBitVec := h ▸ rfl
@[simp] protected theorem toNat_div (a b : $typeName) : (a / b).toNat = a.toNat / b.toNat := BitVec.toNat_udiv ..
protected theorem eq_of_toBitVec_eq {a b : $typeName} (h : a.toBitVec = b.toBitVec) : a = b := by
cases a; cases b; simp_all
@[simp] protected theorem toNat_sub_of_le (a b : $typeName) : b ≤ a → (a - b).toNat = a.toNat - b.toNat := BitVec.toNat_sub_of_le
open $typeName (eq_of_toBitVec_eq toBitVec_eq_of_eq) in
protected theorem toBitVec_inj {a b : $typeName} : a.toBitVec = b.toBitVec ↔ a = b :=
Iff.intro eq_of_toBitVec_eq toBitVec_eq_of_eq
protected theorem toNat_lt_size (a : $typeName) : a.toNat < size := a.toBitVec.isLt
open $typeName (eq_of_toBitVec_eq) in
protected theorem eq_of_val_eq {a b : $typeName} (h : a.val = b.val) : a = b := by
rcases a with ⟨⟨_⟩⟩; rcases b with ⟨⟨_⟩⟩; simp_all [val]
open $typeName (toNat_mod toNat_lt_size) in
protected theorem toNat_mod_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % ofNat m) < m := by
intro u h1
by_cases h2 : m < size
· rw [toNat_mod, toNat_ofNat_of_lt h2]
apply Nat.mod_lt _ h1
· apply Nat.lt_of_lt_of_le
· apply toNat_lt_size
· simpa using h2
open $typeName (eq_of_val_eq) in
protected theorem val_inj {a b : $typeName} : a.val = b.val ↔ a = b :=
Iff.intro eq_of_val_eq (congrArg val)
open $typeName (toNat_mod_lt) in
set_option linter.deprecated false in
@[deprecated toNat_mod_lt (since := "2024-09-24")]
protected theorem modn_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % m) < m := by
intro u
simp only [(· % ·)]
simp only [gt_iff_lt, toNat, modn, Fin.modn_val, BitVec.natCast_eq_ofNat, BitVec.toNat_ofNat,
Nat.reducePow]
rw [Nat.mod_eq_of_lt]
· apply Nat.mod_lt
· apply Nat.lt_of_le_of_lt
· apply Nat.mod_le
· apply Fin.is_lt
open $typeName (toBitVec_eq_of_eq) in
protected theorem ne_of_toBitVec_ne {a b : $typeName} (h : a.toBitVec ≠ b.toBitVec) : a ≠ b :=
fun h' => absurd (toBitVec_eq_of_eq h') h
protected theorem mod_lt (a : $typeName) {b : $typeName} : 0 < b → a % b < b := by
simp only [lt_def, mod_def]
apply BitVec.umod_lt
open $typeName (ne_of_toBitVec_ne) in
protected theorem ne_of_lt {a b : $typeName} (h : a < b) : a ≠ b := by
apply ne_of_toBitVec_ne
apply BitVec.ne_of_lt
simpa [lt_def] using h
protected theorem toNat.inj : ∀ {a b : $typeName}, a.toNat = b.toNat → a = b
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl
@[simp] protected theorem toNat_zero : (0 : $typeName).toNat = 0 := Nat.zero_mod _
@[simp] protected theorem ofNat_one : ofNat 1 = 1 := rfl
@[simp] protected theorem toNat_add (a b : $typeName) : (a + b).toNat = (a.toNat + b.toNat) % 2 ^ $bits := BitVec.toNat_add ..
@[simp]
theorem val_ofNat (n : Nat) : val (no_index (OfNat.ofNat n)) = OfNat.ofNat n := rfl
protected theorem toNat_sub (a b : $typeName) : (a - b).toNat = (2 ^ $bits - b.toNat + a.toNat) % 2 ^ $bits := BitVec.toNat_sub ..
@[simp]
theorem toBitVec_ofNat (n : Nat) : toBitVec (no_index (OfNat.ofNat n)) = BitVec.ofNat _ n := rfl
@[simp] protected theorem toNat_mul (a b : $typeName) : (a * b).toNat = a.toNat * b.toNat % 2 ^ $bits := BitVec.toNat_mul ..
@[simp]
theorem mk_ofNat (n : Nat) : mk (BitVec.ofNat _ n) = OfNat.ofNat n := rfl
@[simp] protected theorem toNat_mod (a b : $typeName) : (a % b).toNat = a.toNat % b.toNat := BitVec.toNat_umod ..
end $typeName
)
@[simp] protected theorem toNat_div (a b : $typeName) : (a / b).toNat = a.toNat / b.toNat := BitVec.toNat_udiv ..
declare_uint_theorems UInt8
declare_uint_theorems UInt16
declare_uint_theorems UInt32
declare_uint_theorems UInt64
declare_uint_theorems USize
@[simp] protected theorem toNat_sub_of_le (a b : $typeName) : b ≤ a → (a - b).toNat = a.toNat - b.toNat := BitVec.toNat_sub_of_le
protected theorem toNat_lt_size (a : $typeName) : a.toNat < size := a.toBitVec.isLt
open $typeName (toNat_mod toNat_lt_size) in
protected theorem toNat_mod_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % ofNat m) < m := by
intro u h1
by_cases h2 : m < size
· rw [toNat_mod, toNat_ofNat_of_lt h2]
apply Nat.mod_lt _ h1
· apply Nat.lt_of_lt_of_le
· apply toNat_lt_size
· simpa using h2
open $typeName (toNat_mod_lt) in
set_option linter.deprecated false in
@[deprecated toNat_mod_lt (since := "2024-09-24")]
protected theorem modn_lt {m : Nat} : ∀ (u : $typeName), m > 0 → toNat (u % m) < m := by
intro u
simp only [(· % ·)]
simp only [gt_iff_lt, toNat, modn, Fin.modn_val, BitVec.natCast_eq_ofNat, BitVec.toNat_ofNat,
Nat.reducePow]
rw [Nat.mod_eq_of_lt]
· apply Nat.mod_lt
· apply Nat.lt_of_le_of_lt
· apply Nat.mod_le
· apply Fin.is_lt
protected theorem mod_lt (a : $typeName) {b : $typeName} : 0 < b → a % b < b := by
simp only [lt_def, mod_def]
apply BitVec.umod_lt
protected theorem toNat.inj : ∀ {a b : $typeName}, a.toNat = b.toNat → a = b
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl
protected theorem toNat_inj : ∀ {a b : $typeName}, a.toNat = b.toNat ↔ a = b :=
Iff.intro toNat.inj (congrArg toNat)
open $typeName (toNat_inj) in
protected theorem le_antisymm_iff {a b : $typeName} : a = b ↔ a ≤ b ∧ b ≤ a :=
toNat_inj.symm.trans Nat.le_antisymm_iff
open $typeName (le_antisymm_iff) in
protected theorem le_antisymm {a b : $typeName} (h₁ : a ≤ b) (h₂ : b ≤ a) : a = b :=
le_antisymm_iff.2 ⟨h₁, h₂⟩
@[simp] protected theorem ofNat_one : ofNat 1 = 1 := rfl
@[simp] protected theorem ofNat_toNat {x : $typeName} : ofNat x.toNat = x := by
apply toNat.inj
simp [Nat.mod_eq_of_lt x.toNat_lt_size]
@[simp]
theorem val_ofNat (n : Nat) : val (no_index (OfNat.ofNat n)) = OfNat.ofNat n := rfl
@[simp]
theorem toBitVec_ofNat (n : Nat) : toBitVec (no_index (OfNat.ofNat n)) = BitVec.ofNat _ n := rfl
@[simp]
theorem mk_ofNat (n : Nat) : mk (BitVec.ofNat _ n) = OfNat.ofNat n := rfl
)
if let some nbits := bits.raw.isNatLit? then
if nbits > 8 then
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUInt8 (x : $typeName) : x.toUInt8.toNat = x.toNat % 2 ^ 8 := rfl)
if nbits < 16 then
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUInt16 (x : $typeName) : x.toUInt16.toNat = x.toNat := rfl)
else if nbits > 16 then
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUInt16 (x : $typeName) : x.toUInt16.toNat = x.toNat % 2 ^ 16 := rfl)
if nbits < 32 then
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUInt32 (x : $typeName) : x.toUInt32.toNat = x.toNat := rfl)
else if nbits > 32 then
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUInt32 (x : $typeName) : x.toUInt32.toNat = x.toNat % 2 ^ 32 := rfl)
if nbits ≤ 32 then
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUSize (x : $typeName) : x.toUSize.toNat = x.toNat := rfl)
else
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUSize (x : $typeName) : x.toUSize.toNat = x.toNat % 2 ^ System.Platform.numBits := rfl)
if nbits < 64 then
cmds := cmds.push <| ←
`(@[simp] theorem toNat_toUInt64 (x : $typeName) : x.toUInt64.toNat = x.toNat := rfl)
cmds := cmds.push <| ← `(end $typeName)
return ⟨mkNullNode cmds⟩
declare_uint_theorems UInt8 8
declare_uint_theorems UInt16 16
declare_uint_theorems UInt32 32
declare_uint_theorems UInt64 64
declare_uint_theorems USize System.Platform.numBits
@[simp] theorem USize.toNat_ofNat32 {n : Nat} {h : n < 4294967296} : (ofNat32 n h).toNat = n := rfl
@[simp] theorem USize.toNat_toUInt32 (x : USize) : x.toUInt32.toNat = x.toNat % 2 ^ 32 := rfl
@[simp] theorem USize.toNat_toUInt64 (x : USize) : x.toUInt64.toNat = x.toNat := rfl
theorem USize.toNat_ofNat_of_lt_32 {n : Nat} (h : n < 4294967296) : toNat (ofNat n) = n :=
toNat_ofNat_of_lt (Nat.lt_of_lt_of_le h le_usize_size)

View file

@ -1692,6 +1692,7 @@ static inline uint8_t lean_uint8_dec_le(uint8_t a1, uint8_t a2) { return a1 <= a
static inline uint16_t lean_uint8_to_uint16(uint8_t a) { return ((uint16_t)a); }
static inline uint32_t lean_uint8_to_uint32(uint8_t a) { return ((uint32_t)a); }
static inline uint64_t lean_uint8_to_uint64(uint8_t a) { return ((uint64_t)a); }
static inline size_t lean_uint8_to_usize(uint8_t a) { return ((size_t)a); }
/* UInt16 */
@ -1727,6 +1728,7 @@ static inline uint8_t lean_uint16_dec_le(uint16_t a1, uint16_t a2) { return a1 <
static inline uint8_t lean_uint16_to_uint8(uint16_t a) { return ((uint8_t)a); }
static inline uint32_t lean_uint16_to_uint32(uint16_t a) { return ((uint32_t)a); }
static inline uint64_t lean_uint16_to_uint64(uint16_t a) { return ((uint64_t)a); }
static inline size_t lean_uint16_to_usize(uint16_t a) { return ((size_t)a); }
/* UInt32 */
@ -1762,7 +1764,7 @@ static inline uint8_t lean_uint32_dec_le(uint32_t a1, uint32_t a2) { return a1 <
static inline uint8_t lean_uint32_to_uint8(uint32_t a) { return ((uint8_t)a); }
static inline uint16_t lean_uint32_to_uint16(uint32_t a) { return ((uint16_t)a); }
static inline uint64_t lean_uint32_to_uint64(uint32_t a) { return ((uint64_t)a); }
static inline size_t lean_uint32_to_usize(uint32_t a) { return a; }
static inline size_t lean_uint32_to_usize(uint32_t a) { return ((size_t)a); }
/* UInt64 */
@ -1834,6 +1836,8 @@ static inline uint8_t lean_usize_dec_le(size_t a1, size_t a2) { return a1 <= a2;
/* usize -> other */
static inline uint8_t lean_usize_to_uint8(size_t a) { return ((uint8_t)a); }
static inline uint16_t lean_usize_to_uint16(size_t a) { return ((uint16_t)a); }
static inline uint32_t lean_usize_to_uint32(size_t a) { return ((uint32_t)a); }
static inline uint64_t lean_usize_to_uint64(size_t a) { return ((uint64_t)a); }