diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 2b668aacca..c383a2bee3 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -21,6 +21,9 @@ set_option linter.missingDocs true namespace BitVec +@[simp] theorem mk_zero : BitVec.ofFin (w := w) ⟨0, h⟩ = 0#w := rfl +@[simp] theorem ofNatLT_zero : BitVec.ofNatLT (w := w) 0 h = 0#w := rfl + @[simp] theorem getLsbD_ofFin (x : Fin (2^n)) (i : Nat) : getLsbD (BitVec.ofFin x) i = x.val.testBit i := rfl @@ -136,6 +139,8 @@ theorem toNat_ne_iff_ne {n} {x y : BitVec n} : x.toNat ≠ y.toNat ↔ x ≠ y : @[bitvec_to_nat] theorem toNat_eq {x y : BitVec n} : x = y ↔ x.toNat = y.toNat := Iff.intro (congrArg BitVec.toNat) eq_of_toNat_eq +theorem toNat_inj {x y : BitVec n} : x.toNat = y.toNat ↔ x = y := toNat_eq.symm + @[bitvec_to_nat] theorem toNat_ne {x y : BitVec n} : x ≠ y ↔ x.toNat ≠ y.toNat := by rw [Ne, toNat_eq] @@ -637,7 +642,7 @@ theorem toInt_nonneg_of_msb_false {x : BitVec w} (h : x.msb = false) : 0 ≤ x.t apply Nat.mod_eq_of_lt apply Nat.one_lt_two_pow (by omega) -/-- Prove equality of bitvectors in terms of nat operations. -/ +/-- Prove equality of bitvectors in terms of integer operations. -/ theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt → x = y := by intro eq simp only [toInt_eq_toNat_cond] at eq diff --git a/src/Init/Data/Int/Pow.lean b/src/Init/Data/Int/Pow.lean index f3463f0e37..0b2f3aa699 100644 --- a/src/Init/Data/Int/Pow.lean +++ b/src/Init/Data/Int/Pow.lean @@ -67,4 +67,10 @@ theorem pow_lt_pow_of_lt {a : Int} {b c : Nat} (ha : 1 < a) (hbc : b < c): | 0 => rfl | k + 1 => by rw [Int.pow_succ, natAbs_mul, natAbs_pow, Nat.pow_succ] +theorem toNat_pow_of_nonneg {x : Int} (h : 0 ≤ x) (k : Nat) : (x ^ k).toNat = x.toNat ^ k := by + induction k with + | zero => simp + | succ k ih => + rw [Int.pow_succ, Int.toNat_mul (Int.pow_nonneg h) h, ih, Nat.pow_succ] + end Int diff --git a/src/Init/Grind/CommRing/Basic.lean b/src/Init/Grind/CommRing/Basic.lean index 2d77d77407..5440d4c873 100644 --- a/src/Init/Grind/CommRing/Basic.lean +++ b/src/Init/Grind/CommRing/Basic.lean @@ -30,6 +30,7 @@ namespace Lean.Grind class CommRing (α : Type u) extends Add α, Mul α, Neg α, Sub α, HPow α Nat α where [ofNat : ∀ n, OfNat α n] + [intCast : IntCast α] add_assoc : ∀ a b c : α, a + b + c = a + (b + c) add_comm : ∀ a b : α, a + b = b + a add_zero : ∀ a : α, a + 0 = a @@ -43,6 +44,8 @@ class CommRing (α : Type u) extends Add α, Mul α, Neg α, Sub α, HPow α Nat pow_zero : ∀ a : α, a ^ 0 = 1 pow_succ : ∀ a : α, ∀ n : Nat, a ^ (n + 1) = (a ^ n) * a ofNat_succ : ∀ a : Nat, OfNat.ofNat (α := α) (a + 1) = OfNat.ofNat a + 1 := by intros; rfl + intCast_ofNat : ∀ n : Nat, Int.cast (OfNat.ofNat (α := Int) n) = OfNat.ofNat (α := α) n := by intros; rfl + intCast_neg : ∀ i : Int, Int.cast (R := α) (-i) = -Int.cast i := by intros; rfl -- We reduce the priority of these parent instances, -- so that in downstream libraries with their own `CommRing` class, @@ -53,6 +56,9 @@ attribute [instance 100] CommRing.toAdd CommRing.toMul CommRing.toNeg CommRing.t -- This is a low-priority instance, to avoid conflicts with existing `OfNat` instances. attribute [instance 100] CommRing.ofNat +-- This is a low-priority instance, to avoid conflicts with existing `IntCast` instances. +attribute [instance 100] CommRing.intCast + namespace CommRing variable {α : Type u} [CommRing α] @@ -69,6 +75,7 @@ theorem ofNat_add (a b : Nat) : OfNat.ofNat (α := α) (a + b) = OfNat.ofNat a + | succ b ih => rw [Nat.add_succ, ofNat_succ, ih, ofNat_succ b, add_assoc] theorem natCast_succ (n : Nat) : ((n + 1 : Nat) : α) = ((n : α) + 1) := ofNat_add _ _ +theorem natCast_add (a b : Nat) : ((a + b : Nat) : α) = ((a : α) + (b : α)) := ofNat_add _ _ theorem zero_add (a : α) : 0 + a = a := by rw [add_comm, add_zero] @@ -96,6 +103,9 @@ theorem ofNat_mul (a b : Nat) : OfNat.ofNat (α := α) (a * b) = OfNat.ofNat a * | zero => simp [Nat.mul_zero, mul_zero] | succ a ih => rw [Nat.mul_succ, ofNat_add, ih, ofNat_add, left_distrib, mul_one] +theorem natCast_mul (a b : Nat) : ((a * b : Nat) : α) = ((a : α) * (b : α)) := by + rw [← ofNat_eq_natCast, ofNat_mul, ofNat_eq_natCast, ofNat_eq_natCast] + theorem add_left_inj {a b : α} (c : α) : a + c = b + c ↔ a = b := ⟨fun h => by simpa [add_assoc, add_neg_cancel, add_zero] using (congrArg (· + -c) h), fun g => congrArg (· + c) g⟩ @@ -134,25 +144,16 @@ theorem sub_eq_iff {a b c : α} : a - b = c ↔ a = c + b := by theorem sub_eq_zero_iff {a b : α} : a - b = 0 ↔ a = b := by simp [sub_eq_iff, zero_add] -instance intCastInst : IntCast α where - intCast n := match n with - | Int.ofNat n => OfNat.ofNat n - | Int.negSucc n => -OfNat.ofNat (n + 1) - -theorem intCast_zero : ((0 : Int) : α) = 0 := rfl -theorem intCast_one : ((1 : Int) : α) = 1 := rfl -theorem intCast_neg_one : ((-1 : Int) : α) = -1 := rfl -theorem intCast_ofNat (n : Nat) : ((n : Int) : α) = (n : α) := rfl -theorem intCast_ofNat_add_one (n : Nat) : ((n + 1 : Int) : α) = (n : α) + 1 := ofNat_add _ _ -theorem intCast_negSucc (n : Nat) : ((-(n + 1) : Int) : α) = -((n : α) + 1) := congrArg (- ·) (ofNat_add _ _) -theorem intCast_neg (x : Int) : ((-x : Int) : α) = - (x : α) := - match x with - | (0 : Nat) => neg_zero.symm - | (n + 1 : Nat) => by - rw [Int.natCast_add, Int.cast_ofNat_Int, intCast_negSucc, intCast_ofNat_add_one] - | -((n : Nat) + 1) => by - rw [Int.neg_neg, intCast_ofNat_add_one, intCast_negSucc, neg_neg] -theorem intCast_nat_add {x y : Nat} : ((x + y : Int) : α) = ((x : α) + (y : α)) := ofNat_add _ _ +theorem intCast_zero : ((0 : Int) : α) = 0 := intCast_ofNat 0 +theorem intCast_one : ((1 : Int) : α) = 1 := intCast_ofNat 1 +theorem intCast_neg_one : ((-1 : Int) : α) = -1 := by rw [intCast_neg, intCast_ofNat] +theorem intCast_natCast (n : Nat) : ((n : Int) : α) = (n : α) := intCast_ofNat n +theorem intCast_natCast_add_one (n : Nat) : ((n + 1 : Int) : α) = (n : α) + 1 := by + rw [← Int.natCast_succ, intCast_natCast, natCast_add, ofNat_eq_natCast] +theorem intCast_negSucc (n : Nat) : ((-(n + 1) : Int) : α) = -((n : α) + 1) := by + rw [intCast_neg, ← Int.natCast_succ, intCast_natCast, ofNat_eq_natCast, natCast_add] +theorem intCast_nat_add {x y : Nat} : ((x + y : Int) : α) = ((x : α) + (y : α)) := by + rw [Int.ofNat_add_ofNat, intCast_natCast, natCast_add] theorem intCast_nat_sub {x y : Nat} (h : x ≥ y) : (((x - y : Nat) : Int) : α) = ((x : α) - (y : α)) := by induction x with | zero => @@ -162,29 +163,30 @@ theorem intCast_nat_sub {x y : Nat} (h : x ≥ y) : (((x - y : Nat) : Int) : α) by_cases h : x + 1 = y · simp [h, intCast_zero, sub_self] · have : ((x + 1 - y : Nat) : Int) = (x - y : Nat) + 1 := by omega - rw [this, intCast_ofNat_add_one] + rw [this, intCast_natCast_add_one] specialize ih (by omega) - rw [intCast_ofNat] at ih + rw [intCast_natCast] at ih rw [ih, natCast_succ, sub_eq_add_neg, sub_eq_add_neg, add_assoc, add_comm _ 1, ← add_assoc] theorem intCast_add (x y : Int) : ((x + y : Int) : α) = ((x : α) + (y : α)) := match x, y with - | (x : Nat), (y : Nat) => ofNat_add _ _ + | (x : Nat), (y : Nat) => by + rw [intCast_nat_add, intCast_natCast, intCast_natCast] | (x : Nat), (-(y + 1 : Nat)) => by by_cases h : x ≥ y + 1 · have : (x + -(y+1 : Nat) : Int) = ((x - (y + 1) : Nat) : Int) := by omega - rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg] + rw [this, intCast_neg, intCast_nat_sub h, intCast_natCast, intCast_natCast, sub_eq_add_neg] · have : (x + -(y+1 : Nat) : Int) = (-(y + 1 - x : Nat) : Int) := by omega - rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat, + rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_natCast, intCast_neg, intCast_natCast, neg_sub, sub_eq_add_neg] | (-(x + 1 : Nat)), (y : Nat) => by by_cases h : y ≥ x+ 1 · have : (-(x+1 : Nat) + y : Int) = ((y - (x + 1) : Nat) : Int) := by omega - rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg, add_comm] + rw [this, intCast_neg, intCast_nat_sub h, intCast_natCast, intCast_natCast, sub_eq_add_neg, add_comm] · have : (-(x+1 : Nat) + y : Int) = (-(x + 1 - y : Nat) : Int) := by omega - rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat, + rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_natCast, intCast_neg, intCast_natCast, neg_sub, sub_eq_add_neg, add_comm] | (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by - rw [← Int.neg_add, intCast_neg, intCast_nat_add, neg_add, intCast_neg, intCast_neg, intCast_ofNat, intCast_ofNat] + rw [← Int.neg_add, intCast_neg, intCast_nat_add, neg_add, intCast_neg, intCast_neg, intCast_natCast, intCast_natCast] theorem intCast_sub (x y : Int) : ((x - y : Int) : α) = ((x : α) - (y : α)) := by rw [Int.sub_eq_add_neg, intCast_add, intCast_neg, sub_eq_add_neg] @@ -200,17 +202,20 @@ theorem neg_mul (a b : α) : (-a) * b = -(a * b) := by theorem mul_neg (a b : α) : a * (-b) = -(a * b) := by rw [mul_comm, neg_mul, mul_comm] -theorem intCast_nat_mul (x y : Nat) : ((x * y : Int) : α) = ((x : α) * (y : α)) := ofNat_mul _ _ +theorem intCast_nat_mul (x y : Nat) : ((x * y : Int) : α) = ((x : α) * (y : α)) := by + rw [Int.ofNat_mul_ofNat, intCast_natCast, natCast_mul] + theorem intCast_mul (x y : Int) : ((x * y : Int) : α) = ((x : α) * (y : α)) := match x, y with - | (x : Nat), (y : Nat) => ofNat_mul _ _ + | (x : Nat), (y : Nat) => by + rw [intCast_nat_mul, intCast_natCast, intCast_natCast] | (x : Nat), (-(y + 1 : Nat)) => by - rw [Int.mul_neg, intCast_neg, intCast_nat_mul, intCast_neg, mul_neg, intCast_ofNat, intCast_ofNat] + rw [Int.mul_neg, intCast_neg, intCast_nat_mul, intCast_neg, mul_neg, intCast_natCast, intCast_natCast] | (-(x + 1 : Nat)), (y : Nat) => by - rw [Int.neg_mul, intCast_neg, intCast_nat_mul, intCast_neg, neg_mul, intCast_ofNat, intCast_ofNat] + rw [Int.neg_mul, intCast_neg, intCast_nat_mul, intCast_neg, neg_mul, intCast_natCast, intCast_natCast] | (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by rw [Int.neg_mul_neg, intCast_neg, intCast_neg, neg_mul, mul_neg, neg_neg, intCast_nat_mul, - intCast_ofNat, intCast_ofNat] + intCast_natCast, intCast_natCast] theorem intCast_pow (x : Int) (k : Nat) : ((x ^ k : Int) : α) = (x : α) ^ k := by induction k @@ -240,10 +245,10 @@ theorem intCast_eq_zero_iff (x : Int) : (x : α) = 0 ↔ x % p = 0 := match x with | (x : Nat) => by have := ofNat_eq_zero_iff (α := α) p (x := x) - rw [Int.ofNat_mod_ofNat] + rw [Int.ofNat_mod_ofNat, intCast_natCast] norm_cast | -(x + 1 : Nat) => by - rw [Int.neg_emod, Int.ofNat_mod_ofNat, intCast_neg, intCast_ofNat, neg_eq_zero] + rw [Int.neg_emod, Int.ofNat_mod_ofNat, intCast_neg, intCast_natCast, neg_eq_zero] have := ofNat_eq_zero_iff (α := α) p (x := x + 1) rw [ofNat_eq_natCast] at this rw [this] @@ -273,7 +278,7 @@ theorem intCast_ext_iff {x y : Int} : (x : α) = (y : α) ↔ x % p = y % p := b theorem ofNat_ext_iff {x y : Nat} : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x % p = y % p := by have := intCast_ext_iff (α := α) p (x := x) (y := y) - simp only [intCast_ofNat, ← Int.ofNat_emod] at this + simp only [intCast_natCast, ← Int.ofNat_emod] at this simp only [ofNat_eq_natCast] norm_cast at this @@ -288,7 +293,7 @@ theorem intCast_emod (x : Int) : ((x % p : Int) : α) = (x : α) := by rw [intCast_ext_iff p, Int.emod_emod] theorem natCast_emod (x : Nat) : ((x % p : Nat) : α) = (x : α) := by - simp only [← intCast_ofNat] + simp only [← intCast_natCast] rw [Int.ofNat_emod, intCast_emod] theorem ofNat_emod (x : Nat) : OfNat.ofNat (α := α) (x % p) = OfNat.ofNat x := diff --git a/src/Init/Grind/CommRing/BitVec.lean b/src/Init/Grind/CommRing/BitVec.lean index 85cb0590b2..0af5910691 100644 --- a/src/Init/Grind/CommRing/BitVec.lean +++ b/src/Init/Grind/CommRing/BitVec.lean @@ -23,6 +23,7 @@ instance : CommRing (BitVec w) where pow_zero _ := BitVec.pow_zero pow_succ _ _ := BitVec.pow_succ ofNat_succ x := BitVec.ofNat_add x 1 + intCast_neg _ := BitVec.ofInt_neg instance : IsCharP (BitVec w) (2 ^ w) where ofNat_eq_zero_iff {x} := by simp [BitVec.ofInt, BitVec.toNat_eq] diff --git a/src/Init/Grind/CommRing/Poly.lean b/src/Init/Grind/CommRing/Poly.lean index 2f4673add3..b9e74b1cd2 100644 --- a/src/Init/Grind/CommRing/Poly.lean +++ b/src/Init/Grind/CommRing/Poly.lean @@ -30,12 +30,18 @@ abbrev Context (α : Type u) := RArray α def Var.denote {α} (ctx : Context α) (v : Var) : α := ctx.get v +def denoteInt {α} [CommRing α] (k : Int) : α := + bif k < 0 then + - OfNat.ofNat (α := α) k.natAbs + else + OfNat.ofNat (α := α) k.natAbs + def Expr.denote {α} [CommRing α] (ctx : Context α) : Expr → α | .add a b => denote ctx a + denote ctx b | .sub a b => denote ctx a - denote ctx b | .mul a b => denote ctx a * denote ctx b | .neg a => -denote ctx a - | .num k => k + | .num k => denoteInt k | .var v => v.denote ctx | .pow a k => denote ctx a ^ k @@ -498,6 +504,11 @@ def NullCert.toPolyC (nc : NullCert) (c : Nat) : Poly := Theorems for justifying the procedure for commutative rings in `grind`. -/ +theorem denoteInt_eq {α} [CommRing α] (k : Int) : denoteInt (α := α) k = k := by + simp [denoteInt, cond_eq_if] <;> split + next h => rw [ofNat_eq_natCast, ← intCast_natCast, ← intCast_neg, ← Int.eq_neg_natAbs_of_nonpos (Int.le_of_lt h)] + next h => rw [ofNat_eq_natCast, ← intCast_natCast, ← Int.eq_natAbs_of_nonneg (Int.le_of_not_gt h)] + theorem Power.denote_eq {α} [CommRing α] (ctx : Context α) (p : Power) : p.denote ctx = p.x.denote ctx ^ p.k := by cases p <;> simp [Power.denote] <;> split <;> simp [pow_zero, pow_succ, one_mul] @@ -677,7 +688,7 @@ theorem Expr.denote_toPoly {α} [CommRing α] (ctx : Context α) (e : Expr) fun_induction toPoly <;> simp [toPoly, denote, Poly.denote, Poly.denote_ofVar, Poly.denote_combine, Poly.denote_mul, Poly.denote_mulConst, Poly.denote_pow, intCast_pow, intCast_neg, intCast_one, - neg_mul, one_mul, sub_eq_add_neg, *] + neg_mul, one_mul, sub_eq_add_neg, denoteInt_eq, *] next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq, mul_one] theorem Expr.eq_of_toPoly_eq {α} [CommRing α] (ctx : Context α) (a b : Expr) (h : a.toPoly == b.toPoly) : a.denote ctx = b.denote ctx := by @@ -845,7 +856,7 @@ theorem Expr.denote_toPolyC {α c} [CommRing α] [IsCharP α c] (ctx : Context unfold toPolyC fun_induction toPolyC.go <;> simp [toPolyC.go, denote, Poly.denote, Poly.denote_ofVar, Poly.denote_combineC, - Poly.denote_mulC, Poly.denote_mulConstC, Poly.denote_powC, *] + Poly.denote_mulC, Poly.denote_mulConstC, Poly.denote_powC, denoteInt_eq, *] next => rw [IsCharP.intCast_emod] next => rw [intCast_neg, neg_mul, intCast_one, one_mul] next => rw [intCast_neg, neg_mul, intCast_one, one_mul, sub_eq_add_neg] diff --git a/src/Init/Grind/CommRing/SInt.lean b/src/Init/Grind/CommRing/SInt.lean index b56da453cd..c8d8881450 100644 --- a/src/Init/Grind/CommRing/SInt.lean +++ b/src/Init/Grind/CommRing/SInt.lean @@ -26,6 +26,7 @@ instance : CommRing Int8 where pow_zero := Int8.pow_zero pow_succ := Int8.pow_succ ofNat_succ x := Int8.ofNat_add x 1 + intCast_neg := Int8.ofInt_neg instance : IsCharP Int8 (2 ^ 8) where ofNat_eq_zero_iff {x} := by @@ -51,7 +52,7 @@ instance : CommRing Int16 where pow_zero := Int16.pow_zero pow_succ := Int16.pow_succ ofNat_succ x := Int16.ofNat_add x 1 - + intCast_neg := Int16.ofInt_neg instance : IsCharP Int16 (2 ^ 16) where ofNat_eq_zero_iff {x} := by have : OfNat.ofNat x = Int16.ofInt x := rfl @@ -76,7 +77,7 @@ instance : CommRing Int32 where pow_zero := Int32.pow_zero pow_succ := Int32.pow_succ ofNat_succ x := Int32.ofNat_add x 1 - + intCast_neg := Int32.ofInt_neg instance : IsCharP Int32 (2 ^ 32) where ofNat_eq_zero_iff {x} := by have : OfNat.ofNat x = Int32.ofInt x := rfl @@ -101,7 +102,7 @@ instance : CommRing Int64 where pow_zero := Int64.pow_zero pow_succ := Int64.pow_succ ofNat_succ x := Int64.ofNat_add x 1 - + intCast_neg := Int64.ofInt_neg instance : IsCharP Int64 (2 ^ 64) where ofNat_eq_zero_iff {x} := by have : OfNat.ofNat x = Int64.ofInt x := rfl @@ -126,7 +127,7 @@ instance : CommRing ISize where pow_zero := ISize.pow_zero pow_succ := ISize.pow_succ ofNat_succ x := ISize.ofNat_add x 1 - + intCast_neg := ISize.ofInt_neg open System.Platform (numBits) instance : IsCharP ISize (2 ^ numBits) where diff --git a/src/Init/Grind/CommRing/UInt.lean b/src/Init/Grind/CommRing/UInt.lean index 5385c65710..0777e83351 100644 --- a/src/Init/Grind/CommRing/UInt.lean +++ b/src/Init/Grind/CommRing/UInt.lean @@ -7,7 +7,6 @@ prelude import Init.Grind.CommRing.Basic import Init.Data.UInt.Lemmas - namespace UInt8 /-- Variant of `UInt8.ofNat_mod_size` replacing `2 ^ 8` with `256`.-/ @@ -16,6 +15,16 @@ theorem ofNat_mod_size' : ofNat (x % 256) = ofNat x := ofNat_mod_size instance : IntCast UInt8 where intCast x := UInt8.ofInt x +theorem intCast_ofNat (x : Nat) : (OfNat.ofNat (α := Int) x : UInt8) = OfNat.ofNat x := by + -- A better proof would be welcome! + simp only [Int.cast, IntCast.intCast] + rw [UInt8.ofInt] + rw [Int.toNat_emod (Int.zero_le_ofNat x) (by decide)] + erw [Int.toNat_natCast] + rw [Int.toNat_pow_of_nonneg (by decide)] + simp only [ofNat, BitVec.ofNat, Fin.ofNat', Int.reduceToNat, Nat.dvd_refl, + Nat.mod_mod_of_dvd, instOfNat] + end UInt8 namespace UInt16 @@ -26,6 +35,16 @@ theorem ofNat_mod_size' : ofNat (x % 65536) = ofNat x := ofNat_mod_size instance : IntCast UInt16 where intCast x := UInt16.ofInt x +theorem intCast_ofNat (x : Nat) : (OfNat.ofNat (α := Int) x : UInt16) = OfNat.ofNat x := by + -- A better proof would be welcome! + simp only [Int.cast, IntCast.intCast] + rw [UInt16.ofInt] + rw [Int.toNat_emod (Int.zero_le_ofNat x) (by decide)] + erw [Int.toNat_natCast] + rw [Int.toNat_pow_of_nonneg (by decide)] + simp only [ofNat, BitVec.ofNat, Fin.ofNat', Int.reduceToNat, Nat.dvd_refl, + Nat.mod_mod_of_dvd, instOfNat] + end UInt16 namespace UInt32 @@ -36,6 +55,16 @@ theorem ofNat_mod_size' : ofNat (x % 4294967296) = ofNat x := ofNat_mod_size instance : IntCast UInt32 where intCast x := UInt32.ofInt x +theorem intCast_ofNat (x : Nat) : (OfNat.ofNat (α := Int) x : UInt32) = OfNat.ofNat x := by + -- A better proof would be welcome! + simp only [Int.cast, IntCast.intCast] + rw [UInt32.ofInt] + rw [Int.toNat_emod (Int.zero_le_ofNat x) (by decide)] + erw [Int.toNat_natCast] + rw [Int.toNat_pow_of_nonneg (by decide)] + simp only [ofNat, BitVec.ofNat, Fin.ofNat', Int.reduceToNat, Nat.dvd_refl, + Nat.mod_mod_of_dvd, instOfNat] + end UInt32 namespace UInt64 @@ -46,6 +75,16 @@ theorem ofNat_mod_size' : ofNat (x % 18446744073709551616) = ofNat x := ofNat_mo instance : IntCast UInt64 where intCast x := UInt64.ofInt x +theorem intCast_ofNat (x : Nat) : (OfNat.ofNat (α := Int) x : UInt64) = OfNat.ofNat x := by + -- A better proof would be welcome! + simp only [Int.cast, IntCast.intCast] + rw [UInt64.ofInt] + rw [Int.toNat_emod (Int.zero_le_ofNat x) (by decide)] + erw [Int.toNat_natCast] + rw [Int.toNat_pow_of_nonneg (by decide)] + simp only [ofNat, BitVec.ofNat, Fin.ofNat', Int.reduceToNat, Nat.dvd_refl, + Nat.mod_mod_of_dvd, instOfNat] + end UInt64 namespace USize @@ -53,9 +92,21 @@ namespace USize instance : IntCast USize where intCast x := USize.ofInt x +theorem intCast_ofNat (x : Nat) : (OfNat.ofNat (α := Int) x : USize) = OfNat.ofNat x := by + -- A better proof would be welcome! + simp only [Int.cast, IntCast.intCast] + rw [USize.ofInt] + rw [Int.toNat_emod (Int.zero_le_ofNat x)] + · erw [Int.toNat_natCast] + rw [Int.toNat_pow_of_nonneg (by decide)] + simp only [ofNat, BitVec.ofNat, Fin.ofNat', Int.reduceToNat, Nat.dvd_refl, + Nat.mod_mod_of_dvd, instOfNat] + · obtain _ | _ := System.Platform.numBits_eq <;> simp_all + end USize namespace Lean.Grind + instance : CommRing UInt8 where add_assoc := UInt8.add_assoc add_comm := UInt8.add_comm @@ -70,6 +121,8 @@ instance : CommRing UInt8 where pow_zero := UInt8.pow_zero pow_succ := UInt8.pow_succ ofNat_succ x := UInt8.ofNat_add x 1 + intCast_neg := UInt8.ofInt_neg + intCast_ofNat := UInt8.intCast_ofNat instance : IsCharP UInt8 256 where ofNat_eq_zero_iff {x} := by @@ -90,6 +143,8 @@ instance : CommRing UInt16 where pow_zero := UInt16.pow_zero pow_succ := UInt16.pow_succ ofNat_succ x := UInt16.ofNat_add x 1 + intCast_neg := UInt16.ofInt_neg + intCast_ofNat := UInt16.intCast_ofNat instance : IsCharP UInt16 65536 where ofNat_eq_zero_iff {x} := by @@ -110,6 +165,8 @@ instance : CommRing UInt32 where pow_zero := UInt32.pow_zero pow_succ := UInt32.pow_succ ofNat_succ x := UInt32.ofNat_add x 1 + intCast_neg := UInt32.ofInt_neg + intCast_ofNat := UInt32.intCast_ofNat instance : IsCharP UInt32 4294967296 where ofNat_eq_zero_iff {x} := by @@ -130,6 +187,8 @@ instance : CommRing UInt64 where pow_zero := UInt64.pow_zero pow_succ := UInt64.pow_succ ofNat_succ x := UInt64.ofNat_add x 1 + intCast_neg := UInt64.ofInt_neg + intCast_ofNat := UInt64.intCast_ofNat instance : IsCharP UInt64 18446744073709551616 where ofNat_eq_zero_iff {x} := by @@ -150,6 +209,8 @@ instance : CommRing USize where pow_zero := USize.pow_zero pow_succ := USize.pow_succ ofNat_succ x := USize.ofNat_add x 1 + intCast_neg := USize.ofInt_neg + intCast_ofNat := USize.intCast_ofNat open System.Platform diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 05b7cea128..6784e6d7df 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -2037,23 +2037,23 @@ structure BitVec (w : Nat) where /-- Bitvectors have decidable equality. -This should be used via the instance `DecidableEq (BitVec n)`. +This should be used via the instance `DecidableEq (BitVec w)`. -/ -- We manually derive the `DecidableEq` instances for `BitVec` because -- we want to have builtin support for bit-vector literals, and we -- need a name for this function to implement `canUnfoldAtMatcher` at `WHNF.lean`. -def BitVec.decEq (x y : BitVec n) : Decidable (Eq x y) := +def BitVec.decEq (x y : BitVec w) : Decidable (Eq x y) := match x, y with | ⟨n⟩, ⟨m⟩ => dite (Eq n m) (fun h => isTrue (h ▸ rfl)) (fun h => isFalse (fun h' => BitVec.noConfusion h' (fun h' => absurd h' h))) -instance : DecidableEq (BitVec n) := BitVec.decEq +instance : DecidableEq (BitVec w) := BitVec.decEq -/-- The `BitVec` with value `i`, given a proof that `i < 2^n`. -/ +/-- The `BitVec` with value `i`, given a proof that `i < 2^w`. -/ @[match_pattern] -protected def BitVec.ofNatLT {n : Nat} (i : Nat) (p : LT.lt i (hPow 2 n)) : BitVec n where +protected def BitVec.ofNatLT {w : Nat} (i : Nat) (p : LT.lt i (hPow 2 w)) : BitVec w where toFin := ⟨i, p⟩ /-- @@ -2061,14 +2061,14 @@ Return the underlying `Nat` that represents a bitvector. This is O(1) because `BitVec` is a (zero-cost) wrapper around a `Nat`. -/ -protected def BitVec.toNat (x : BitVec n) : Nat := x.toFin.val +protected def BitVec.toNat (x : BitVec w) : Nat := x.toFin.val -instance : LT (BitVec n) where lt := (LT.lt ·.toNat ·.toNat) -instance (x y : BitVec n) : Decidable (LT.lt x y) := +instance : LT (BitVec w) where lt := (LT.lt ·.toNat ·.toNat) +instance (x y : BitVec w) : Decidable (LT.lt x y) := inferInstanceAs (Decidable (LT.lt x.toNat y.toNat)) -instance : LE (BitVec n) where le := (LE.le ·.toNat ·.toNat) -instance (x y : BitVec n) : Decidable (LE.le x y) := +instance : LE (BitVec w) where le := (LE.le ·.toNat ·.toNat) +instance (x y : BitVec w) : Decidable (LE.le x y) := inferInstanceAs (Decidable (LE.le x.toNat y.toNat)) /-- The number of distinct values representable by `UInt8`, that is, `2^8 = 256`. -/ diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean index 96eaa23b4e..d5b45be7d1 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean @@ -57,16 +57,13 @@ private def getPowFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Exp throwError "instance for power operator{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" internalizeFn <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst -private def getIntCastFn (type : Expr) (u : Level) (_commRingInst : Expr) : GoalM Expr := do +private def getIntCastFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do let instType := mkApp (mkConst ``IntCast [u]) type let .some inst ← trySynthInstance instType | throwError "failed to find instance for ring intCast{indentExpr instType}" - -- TODO uncomment after we fix `CommRing` definition - /- - let inst' := mkApp2 (mkConst ``Grind.CommRing.intCastInst [u]) type commRingInst + let inst' := mkApp2 (mkConst ``Grind.CommRing.intCast [u]) type commRingInst unless (← withDefault <| isDefEq inst inst') do throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" - -/ internalizeFn <| mkApp2 (mkConst ``IntCast.intCast [u]) type inst private def getNatCastFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do