chore: define udiv normal form to be /, resp. umod and % (#5645)

This follows the norm for all other Bitvector operations, and makes the
symbols `/` and `%` the simp normal form.

I'd imagine that @hargonix would prefer that this be merged after
https://github.com/leanprover/lean4/pull/5628, so as to prevent churn
for his PR. I'm happy to rebase the PR once the other PR lands.

---------

Co-authored-by: Henrik Böving <hargonix@gmail.com>
This commit is contained in:
Siddharth 2024-10-08 03:49:46 -05:00 committed by GitHub
parent 4415a81f35
commit 2ed7924bae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 21 additions and 18 deletions

View file

@ -718,6 +718,8 @@ section normalization_eqs
@[simp] theorem add_eq (x y : BitVec w) : BitVec.add x y = x + y := rfl
@[simp] theorem sub_eq (x y : BitVec w) : BitVec.sub x y = x - y := rfl
@[simp] theorem mul_eq (x y : BitVec w) : BitVec.mul x y = x * y := rfl
@[simp] theorem udiv_eq (x y : BitVec w) : BitVec.udiv x y = x / y := rfl
@[simp] theorem umod_eq (x y : BitVec w) : BitVec.umod x y = x % y := rfl
@[simp] theorem zero_eq : BitVec.zero n = 0#n := rfl
end normalization_eqs

View file

@ -571,7 +571,7 @@ then `n.udiv d = q`. -/
theorem udiv_eq_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 < d)
(hrd : r < d)
(hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) :
n.udiv d = q := by
n / d = q := by
apply BitVec.eq_of_toNat_eq
rw [toNat_udiv]
replace hdqnr : (d.toNat * q.toNat + r.toNat) / d.toNat = n.toNat / d.toNat := by
@ -587,7 +587,7 @@ theorem udiv_eq_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 < d)
then `n.umod d = r`. -/
theorem umod_eq_of_mul_add_toNat {d n q r : BitVec w} (hrd : r < d)
(hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) :
n.umod d = r := by
n % d = r := by
apply BitVec.eq_of_toNat_eq
rw [toNat_umod]
replace hdqnr : (d.toNat * q.toNat + r.toNat) % d.toNat = n.toNat % d.toNat := by
@ -688,7 +688,7 @@ quotient has been correctly computed.
theorem DivModState.udiv_eq_of_lawful {n d : BitVec w} {qr : DivModState w}
(h_lawful : DivModState.Lawful {n, d} qr)
(h_final : qr.wn = 0) :
n.udiv d = qr.q := by
n / d = qr.q := by
apply udiv_eq_of_mul_add_toNat h_lawful.hdPos h_lawful.hrLtDivisor
have hdiv := h_lawful.hdiv
simp only [h_final] at *
@ -701,7 +701,7 @@ remainder has been correctly computed.
theorem DivModState.umod_eq_of_lawful {qr : DivModState w}
(h : DivModState.Lawful {n, d} qr)
(h_final : qr.wn = 0) :
n.umod d = qr.r := by
n % d = qr.r := by
apply umod_eq_of_mul_add_toNat h.hrLtDivisor
have hdiv := h.hdiv
simp only [shiftRight_zero] at hdiv
@ -875,7 +875,7 @@ theorem wn_divRec (args : DivModArgs w) (qr : DivModState w) :
/-- The result of `udiv` agrees with the result of the division recurrence. -/
theorem udiv_eq_divRec (hd : 0#w < d) :
let out := divRec w {n, d} (DivModState.init w)
n.udiv d = out.q := by
n / d = out.q := by
have := DivModState.lawful_init {n, d} hd
have := lawful_divRec this
apply DivModState.udiv_eq_of_lawful this (wn_divRec ..)
@ -883,7 +883,7 @@ theorem udiv_eq_divRec (hd : 0#w < d) :
/-- The result of `umod` agrees with the result of the division recurrence. -/
theorem umod_eq_divRec (hd : 0#w < d) :
let out := divRec w {n, d} (DivModState.init w)
n.umod d = out.r := by
n % d = out.r := by
have := DivModState.lawful_init {n, d} hd
have := lawful_divRec this
apply DivModState.umod_eq_of_lawful this (wn_divRec ..)

View file

@ -1339,37 +1339,38 @@ theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNa
/-! ### udiv -/
theorem udiv_eq {x y : BitVec n} : x.udiv y = BitVec.ofNat n (x.toNat / y.toNat) := by
theorem udiv_def {x y : BitVec n} : x / y = BitVec.ofNat n (x.toNat / y.toNat) := by
have h : x.toNat / y.toNat < 2 ^ n := Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega)
rw [← udiv_eq]
simp [udiv, bv_toNat, h, Nat.mod_eq_of_lt]
@[simp, bv_toNat]
theorem toNat_udiv {x y : BitVec n} : (x.udiv y).toNat = x.toNat / y.toNat := by
simp only [udiv_eq]
theorem toNat_udiv {x y : BitVec n} : (x / y).toNat = x.toNat / y.toNat := by
rw [udiv_def]
by_cases h : y = 0
· simp [h]
· rw [toNat_ofNat, Nat.mod_eq_of_lt]
exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega)
@[simp]
theorem udiv_zero {x : BitVec n} : x.udiv 0#n = 0#n := by
simp only [udiv, toNat_ofNat, Nat.zero_mod, Nat.div_zero]
rfl
theorem udiv_zero {x : BitVec n} : x / 0#n = 0#n := by
simp [udiv_def]
/-! ### umod -/
theorem umod_eq {x y : BitVec n} :
x.umod y = BitVec.ofNat n (x.toNat % y.toNat) := by
theorem umod_def {x y : BitVec n} :
x % y = BitVec.ofNat n (x.toNat % y.toNat) := by
rw [← umod_eq]
have h : x.toNat % y.toNat < 2 ^ n := Nat.lt_of_le_of_lt (Nat.mod_le _ _) x.isLt
simp [umod, bv_toNat, Nat.mod_eq_of_lt h]
@[simp, bv_toNat]
theorem toNat_umod {x y : BitVec n} :
(x.umod y).toNat = x.toNat % y.toNat := rfl
(x % y).toNat = x.toNat % y.toNat := rfl
@[simp]
theorem umod_zero {x : BitVec n} : x.umod 0#n = x := by
simp [umod_eq]
theorem umod_zero {x : BitVec n} : x % 0#n = x := by
simp [umod_def]
/-! ### sdiv -/
@ -2203,7 +2204,7 @@ protected theorem ne_of_lt {x y : BitVec n} : x < y → x ≠ y := by
simp only [lt_def, ne_eq, toNat_eq]
apply Nat.ne_of_lt
protected theorem umod_lt (x : BitVec n) {y : BitVec n} : 0 < y → x.umod y < y := by
protected theorem umod_lt (x : BitVec n) {y : BitVec n} : 0 < y → x % y < y := by
simp only [ofNat_eq_ofNat, lt_def, toNat_ofNat, Nat.zero_mod, umod, toNat_ofNatLt]
apply Nat.mod_lt