feat: add instance Grind.CommRing (Fin n) (#8276)

This PR adds the instances `Grind.CommRing (Fin n)` and `Grind.IsCharP
(Fin n) n`. New tests:
```lean
example (x y z : Fin 13) :
    (x + y + z) ^ 2 = x ^ 2 + y ^ 2 + z ^ 2 + 2 * (x * y + y * z + z * x) := by
  grind +ring

example (x y : Fin 17) : (x + y) ^ 3 = x ^ 3 + y ^ 3 + 3 * x * y * (x + y) := by
  grind +ring

example (x y : Fin 19) : (x - y) * (x ^ 2 + x * y + y ^ 2) = x ^ 3 - y ^ 3 := by
  grind +ring
```

---------

Co-authored-by: Kim Morrison <kim@tqft.net>
This commit is contained in:
Leonardo de Moura 2025-05-13 05:09:02 -07:00 committed by GitHub
parent 2299c3c9ec
commit e0a266780b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 312 additions and 75 deletions

View file

@ -150,7 +150,7 @@ with `BitVec.toInt` results in the value `i.bmod (2^n)`.
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLT (i % (Int.ofNat (2^n))).toNat (by
apply (Int.toNat_lt _).mpr
· apply Int.emod_lt_of_pos
exact Int.ofNat_pos.mpr (Nat.two_pow_pos _)
exact Int.natCast_pos.mpr (Nat.two_pow_pos _)
· apply Int.emod_nonneg
intro eq
apply Nat.ne_of_gt (Nat.two_pow_pos n)

View file

@ -315,6 +315,12 @@ theorem ofFin_ofNat (n : Nat) :
ofFin (no_index (OfNat.ofNat n : Fin (2^w))) = OfNat.ofNat n := by
simp only [OfNat.ofNat, Fin.ofNat', BitVec.ofNat, Nat.and_two_pow_sub_one_eq_mod]
@[simp] theorem ofFin_neg {x : Fin (2 ^ w)} : ofFin (-x) = -(ofFin x) := by
rfl
@[simp, norm_cast] theorem ofFin_natCast (n : Nat) : ofFin (n : Fin (2^w)) = (n : BitVec w) := by
rfl
theorem eq_of_toFin_eq : ∀ {x y : BitVec w}, x.toFin = y.toFin → x = y
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl
@ -330,6 +336,9 @@ theorem toFin_zero : toFin (0 : BitVec w) = 0 := rfl
theorem toFin_one : toFin (1 : BitVec w) = 1 := by
rw [toFin_inj]; simp only [ofNat_eq_ofNat, ofFin_ofNat]
@[simp, norm_cast] theorem toFin_natCast (n : Nat) : toFin (n : BitVec w) = (n : Fin (2^w)) := by
rfl
@[simp] theorem toNat_ofBool (b : Bool) : (ofBool b).toNat = b.toNat := by
cases b <;> rfl
@ -672,6 +681,10 @@ theorem toInt_ne {x y : BitVec n} : x.toInt ≠ y.toInt ↔ x ≠ y := by
theorem toInt_ofNat {n : Nat} (x : Nat) : (BitVec.ofNat n x).toInt = (x : Int).bmod (2^n) := by
simp [toInt_eq_toNat_bmod, -Int.natCast_pow]
@[simp] theorem toInt_ofFin {w : Nat} (x : Fin (2^w)) :
(BitVec.ofFin x).toInt = Int.bmod x (2^w) := by
simp [toInt_eq_toNat_bmod]
@[simp] theorem toInt_ofInt {n : Nat} (i : Int) :
(BitVec.ofInt n i).toInt = i.bmod (2^n) := by
have _ := Nat.two_pow_pos n
@ -777,7 +790,6 @@ theorem le_two_mul_toInt {w : Nat} {x : BitVec w} : -2 ^ w ≤ 2 * x.toInt := by
simp only [Nat.zero_lt_succ, Nat.mul_lt_mul_left, Int.natCast_mul, Int.cast_ofNat_Int]
norm_cast; omega
theorem le_toInt {w : Nat} (x : BitVec w) : -2 ^ (w - 1) ≤ x.toInt := by
by_cases h : w = 0
· subst h
@ -793,6 +805,17 @@ theorem le_toInt {w : Nat} (x : BitVec w) : -2 ^ (w - 1) ≤ x.toInt := by
· simpa [Int.mul_comm _ 2] using le_two_mul_toInt
· simpa [Int.mul_comm _ 2] using two_mul_toInt_lt
@[simp] theorem toNat_intCast {w : Nat} (x : Int) : (x : BitVec w).toNat = (x % 2^w).toNat := by
change (BitVec.ofInt w x).toNat = _
simp
@[simp] theorem toInt_intCast {w : Nat} (x : Int) : (x : BitVec w).toInt = Int.bmod x (2^w) := by
rw [toInt_eq_toNat_bmod, toNat_intCast, Int.natCast_toNat_eq_self.mpr]
· have h : (2 ^ w : Int) = (2 ^ w : Nat) := by simp
rw [h, Int.emod_bmod]
· apply Int.emod_nonneg
exact Int.pow_ne_zero (by decide)
/-! ### sle/slt -/
/--

View file

@ -230,6 +230,19 @@ instance : ShiftRight (Fin n) where
instance instOfNat {n : Nat} [NeZero n] {i : Nat} : OfNat (Fin n) i where
ofNat := Fin.ofNat' n i
/-- If you actually have an element of `Fin n`, then the `n` is always positive -/
protected theorem pos (i : Fin n) : 0 < n :=
Nat.lt_of_le_of_lt (Nat.zero_le _) i.2
/-- Negation on `Fin n` -/
instance neg (n : Nat) : Neg (Fin n) :=
⟨fun a => ⟨(n - a) % n, Nat.mod_lt _ a.pos⟩⟩
theorem neg_def (a : Fin n) : -a = ⟨(n - a) % n, Nat.mod_lt _ a.pos⟩ := rfl
protected theorem coe_neg (a : Fin n) : ((-a : Fin n) : Nat) = (n - a) % n :=
rfl
instance instInhabited {n : Nat} [NeZero n] : Inhabited (Fin n) where
default := 0
@ -247,10 +260,6 @@ theorem modn_lt : ∀ {m : Nat} (i : Fin n), m > 0 → (modn i m).val < m
theorem val_lt_of_le (i : Fin b) (h : b ≤ n) : i.val < n :=
Nat.lt_of_lt_of_le i.isLt h
/-- If you actually have an element of `Fin n`, then the `n` is always positive -/
protected theorem pos (i : Fin n) : 0 < n :=
Nat.lt_of_le_of_lt (Nat.zero_le _) i.2
/--
The greatest value of `Fin (n+1)`, namely `n`.

View file

@ -8,6 +8,7 @@ module
prelude
import Init.Data.Fin.Basic
import Init.Data.Nat.Lemmas
import Init.Data.Int.DivMod.Lemmas
import Init.Ext
import Init.ByCases
import Init.Conv
@ -99,6 +100,21 @@ theorem dite_val {n : Nat} {c : Prop} [Decidable c] {x y : Fin n} :
(if c then x else y).val = if c then x.val else y.val := by
by_cases c <;> simp [*]
instance (n : Nat) [NeZero n] : NatCast (Fin n) where
natCast a := Fin.ofNat' n a
def intCast [NeZero n] (a : Int) : Fin n :=
if 0 ≤ a then
Fin.ofNat' n a.natAbs
else
- Fin.ofNat' n a.natAbs
instance (n : Nat) [NeZero n] : IntCast (Fin n) where
intCast := Fin.intCast
theorem intCast_def {n : Nat} [NeZero n] (x : Int) :
(x : Fin n) = if 0 ≤ x then Fin.ofNat' n x.natAbs else -Fin.ofNat' n x.natAbs := rfl
/-! ### order -/
theorem le_def {a b : Fin n} : a ≤ b ↔ a.1 ≤ b.1 := .rfl
@ -156,7 +172,7 @@ protected theorem eq_or_lt_of_le {a b : Fin n} : a ≤ b → a = b a < b :=
protected theorem lt_or_eq_of_le {a b : Fin n} : a ≤ b → a < b a = b := by
rw [Fin.ext_iff]; exact Nat.lt_or_eq_of_le
theorem is_le (i : Fin (n + 1)) : i ≤ n := Nat.le_of_lt_succ i.is_lt
theorem is_le (i : Fin (n + 1)) : i.1 ≤ n := Nat.le_of_lt_succ i.is_lt
@[simp] theorem is_le' {a : Fin n} : a ≤ n := Nat.le_of_lt a.is_lt
@ -219,7 +235,7 @@ theorem rev_eq {n a : Nat} (i : Fin (n + 1)) (h : n = a + i) :
/-! ### last -/
@[simp] theorem val_last (n : Nat) : last n = n := rfl
@[simp] theorem val_last (n : Nat) : (last n).1 = n := rfl
@[simp] theorem last_zero : (Fin.last 0 : Fin 1) = 0 := by
ext
@ -260,7 +276,7 @@ theorem subsingleton_iff_le_one : Subsingleton (Fin n) ↔ n ≤ 1 := by
(match n with | 0 | 1 | n+2 => ?_) <;> try simp
· exact ⟨nofun⟩
· exact ⟨fun ⟨0, _⟩ ⟨0, _⟩ => rfl⟩
· exact iff_of_false (fun h => Fin.ne_of_lt zero_lt_one (h.elim ..)) (of_decide_eq_false rfl)
· exact fun h => by have := zero_lt_one (n := n); simp_all [h.elim 0 1]
instance subsingleton_zero : Subsingleton (Fin 0) := subsingleton_iff_le_one.2 (by decide)
@ -925,6 +941,15 @@ theorem addCases_right {m n : Nat} {motive : Fin (m + n) → Sort _} {left right
have : ¬(natAdd m i : Nat) < m := Nat.not_lt.2 (le_coe_natAdd ..)
rw [addCases, dif_neg this]; exact eq_of_heq <| (eqRec_heq _ _).trans (by congr 1; simp)
/-! ### zero -/
@[simp, norm_cast]
theorem val_eq_zero_iff [NeZero n] {a : Fin n} : a.val = 0 ↔ a = 0 := by
rw [Fin.ext_iff, val_zero]
theorem val_ne_zero_iff [NeZero n] {a : Fin n} : a.val ≠ 0 ↔ a ≠ 0 :=
not_congr val_eq_zero_iff
/-! ### add -/
theorem ofNat'_add [NeZero n] (x : Nat) (y : Fin n) :
@ -984,6 +1009,17 @@ theorem coe_sub_iff_lt {a b : Fin n} : (↑(a - b) : Nat) = n + a - b ↔ a < b
rw [Nat.mod_eq_of_lt]
all_goals omega
/-! ### neg -/
theorem val_neg {n : Nat} [NeZero n] (x : Fin n) :
(-x).val = if x = 0 then 0 else n - x.val := by
change (n - ↑x) % n = _
split <;> rename_i h
· simp_all
· rw [Nat.mod_eq_of_lt]
have := Fin.val_ne_zero_iff.mpr h
omega
/-! ### mul -/
theorem ofNat'_mul [NeZero n] (x : Nat) (y : Fin n) :

View file

@ -108,7 +108,7 @@ theorem resolve_left_lt_lcm (a c d p x : Int) (a_pos : 0 < a) (d_pos : 0 < d) (h
resolve_left a c d p x < lcm a (a * d / gcd (a * d) c) := by
simp only [h₁, resolve_left_eq, resolve_left', add_of_le, Int.ofNat_lt]
exact Nat.mod_lt _ (Nat.pos_of_ne_zero (lcm_ne_zero (Int.ne_of_gt a_pos)
(Int.ne_of_gt (Int.ediv_pos_of_pos_of_dvd (Int.mul_pos a_pos d_pos) (Int.ofNat_nonneg _)
(Int.ne_of_gt (Int.ediv_pos_of_pos_of_dvd (Int.mul_pos a_pos d_pos) (Int.natCast_nonneg _)
(gcd_dvd_left _ _)))))
theorem resolve_left_ineq (a c d p x : Int) (a_pos : 0 < a) (b_pos : 0 < b)

View file

@ -23,6 +23,8 @@ open Nat (succ)
namespace Int
@[simp high] theorem natCast_eq_zero {n : Nat} : (n : Int) = 0 ↔ n = 0 := by omega
protected theorem exists_add_of_le {a b : Int} (h : a ≤ b) : ∃ (c : Nat), b = a + c :=
⟨(b - a).toNat, by omega⟩
@ -143,6 +145,20 @@ theorem dvd_of_mul_dvd_mul_left {a m n : Int} (ha : a ≠ 0) (h : a * m a *
theorem dvd_of_mul_dvd_mul_right {a m n : Int} (ha : a ≠ 0) (h : m * a n * a) : m n :=
dvd_of_mul_dvd_mul_left ha (by simpa [Int.mul_comm] using h)
@[norm_cast] theorem natCast_dvd_natCast {m n : Nat} : (↑m : Int) ↑n ↔ m n where
mp := by
rintro ⟨a, h⟩
obtain rfl | hm := m.eq_zero_or_pos
· simpa using h
have ha : 0 ≤ a := Int.not_lt.1 fun ha ↦ by
simpa [← h, Int.not_lt.2 (Int.natCast_nonneg _)]
using Int.mul_neg_of_pos_of_neg (natCast_pos.2 hm) ha
match a, ha with
| (a : Nat), _ =>
norm_cast at h
exact ⟨a, h⟩
mpr := by rintro ⟨a, rfl⟩; simp [Int.dvd_mul_right]
/-! ### *div zero -/
@[simp] protected theorem zero_tdiv : ∀ b : Int, tdiv 0 b = 0
@ -1031,6 +1047,27 @@ theorem ediv_dvd_of_dvd {m n : Int} (hmn : m n) : n / m n := by
· obtain ⟨a, ha⟩ := hmn
simp [ha, Int.mul_ediv_cancel_left _ hm, Int.dvd_mul_left]
theorem emod_natAbs_of_nonneg {x : Int} (h : 0 ≤ x) {n : Nat} :
x.natAbs % n = (x % n).toNat := by
match x, h with
| (x : Nat), _ => rw [Int.natAbs_natCast, Int.ofNat_mod_ofNat, Int.toNat_natCast]
theorem emod_natAbs_of_neg {x : Int} (h : x < 0) {n : Nat} (w : n ≠ 0) :
x.natAbs % n = if (n : Int) x then 0 else n - (x % n).toNat := by
match x, h with
| -(x + 1 : Nat), _ =>
rw [Int.natAbs_neg]
rw [Int.natAbs_cast]
rw [Int.neg_emod]
simp only [Int.dvd_neg]
simp only [Int.natCast_dvd_natCast]
split <;> rename_i h
· rw [Nat.mod_eq_zero_of_dvd h]
· rw [← Int.natCast_emod]
simp only [Int.natAbs_natCast]
have : (x + 1) % n < n := Nat.mod_lt (x + 1) (by omega)
omega
/-! ### `/` and ordering -/
protected theorem ediv_mul_le (a : Int) {b : Int} (H : b ≠ 0) : a / b * b ≤ a :=
@ -1308,7 +1345,7 @@ theorem sign_tdiv (a b : Int) : sign (a.tdiv b) = if natAbs a < natAbs b then 0
theorem ofNat_tmod (m n : Nat) : (↑(m % n) : Int) = tmod m n := rfl
theorem tmod_nonneg : ∀ {a : Int} (b : Int), 0 ≤ a → 0 ≤ tmod a b
| ofNat _, -[_+1], _ | ofNat _, ofNat _, _ => ofNat_nonneg _
| ofNat _, -[_+1], _ | ofNat _, ofNat _, _ => natCast_nonneg _
@[simp] theorem tmod_neg (a b : Int) : tmod a (-b) = tmod a b := by
rw [tmod_def, tmod_def, Int.tdiv_neg, Int.neg_mul_neg]
@ -1321,7 +1358,7 @@ theorem tmod_lt_of_pos (a : Int) {b : Int} (H : 0 < b) : tmod a b < b :=
match a, b, eq_succ_of_zero_lt H with
| ofNat _, _, ⟨n, rfl⟩ => ofNat_lt.2 <| Nat.mod_lt _ n.succ_pos
| -[_+1], _, ⟨n, rfl⟩ => Int.lt_of_le_of_lt
(Int.neg_nonpos_of_nonneg <| Int.ofNat_nonneg _) (ofNat_pos.2 n.succ_pos)
(Int.neg_nonpos_of_nonneg <| natCast_nonneg _) (natCast_pos.2 n.succ_pos)
theorem lt_tmod_of_pos (a : Int) {b : Int} (H : 0 < b) : -b < tmod a b :=
match a, b, eq_succ_of_zero_lt H with
@ -1724,7 +1761,7 @@ protected theorem tdiv_mul_le (a : Int) {b : Int} (hb : b ≠ 0) : a.tdiv b * b
· simp_all [tmod_nonneg]
· match b, hb with
| .ofNat (b + 1), _ =>
have := lt_tmod_of_pos a (Int.ofNat_pos.2 (b.succ_pos))
have := lt_tmod_of_pos a (natCast_pos.2 (b.succ_pos))
simp_all
omega
| .negSucc b, _ =>
@ -2679,8 +2716,8 @@ theorem le_bmod {x : Int} {m : Nat} (h : 0 < m) : - (m/2) ≤ Int.bmod x m := by
have v : (m : Int) % 2 = 0 (m : Int) % 2 = 1 := emod_two_eq _
split <;> rename_i w
· refine Int.le_trans ?_ (Int.emod_nonneg _ ?_)
· exact Int.neg_nonpos_of_nonneg (Int.ediv_nonneg (Int.ofNat_nonneg _) (by decide))
· exact Int.ne_of_gt (ofNat_pos.mpr h)
· exact Int.neg_nonpos_of_nonneg (Int.ediv_nonneg (natCast_nonneg _) (by decide))
· exact Int.ne_of_gt (natCast_pos.mpr h)
· simp [Int.not_lt] at w
refine Int.le_trans ?_ (Int.sub_le_sub_right w _)
rw [← ediv_add_emod m 2]
@ -2713,7 +2750,7 @@ theorem bmod_lt {x : Int} {m : Nat} (h : 0 < m) : bmod x m < (m + 1) / 2 := by
· assumption
· apply Int.lt_of_lt_of_le
· show _ < 0
have : x % m < m := emod_lt_of_pos x (ofNat_pos.mpr h)
have : x % m < m := emod_lt_of_pos x (natCast_pos.mpr h)
exact Int.sub_neg_of_lt this
· exact Int.le.intro_sub _ rfl

View file

@ -594,4 +594,6 @@ protected theorem natCast_zero : ((0 : Nat) : Int) = (0 : Int) := rfl
protected theorem natCast_one : ((1 : Nat) : Int) = (1 : Int) := rfl
@[simp, norm_cast] theorem natAbs_cast (n : Nat) : natAbs ↑n = n := rfl
end Int

View file

@ -25,7 +25,7 @@ namespace Int
@[simp] protected theorem neg_nonpos_iff (i : Int) : -i ≤ 0 ↔ 0 ≤ i := by omega
@[simp] theorem zero_le_ofNat (n : Nat) : 0 ≤ ((no_index (OfNat.ofNat n)) : Int) :=
ofNat_nonneg _
natCast_nonneg _
@[simp] theorem neg_natCast_le_natCast (n m : Nat) : -(n : Int) ≤ (m : Int) :=
Int.le_trans (by simp) (ofNat_zero_le m)
@ -52,25 +52,17 @@ protected theorem ofNat_add_one_out (n : Nat) : ↑n + (1 : Int) = ↑(Nat.succ
@[norm_cast] theorem natCast_inj {m n : Nat} : (m : Int) = (n : Int) ↔ m = n := ofNat_inj
@[simp, norm_cast] theorem natAbs_cast (n : Nat) : natAbs ↑n = n := rfl
@[norm_cast]
protected theorem natCast_sub {n m : Nat} : n ≤ m → (↑(m - n) : Int) = ↑m - ↑n := ofNat_sub
@[simp high] theorem natCast_eq_zero {n : Nat} : (n : Int) = 0 ↔ n = 0 := by omega
theorem natCast_ne_zero {n : Nat} : (n : Int) ≠ 0 ↔ n ≠ 0 := by omega
theorem natCast_ne_zero_iff_pos {n : Nat} : (n : Int) ≠ 0 ↔ 0 < n := by omega
@[simp high] theorem natCast_pos {n : Nat} : (0 : Int) < n ↔ 0 < n := by omega
theorem natCast_succ_pos (n : Nat) : 0 < (n.succ : Int) := natCast_pos.2 n.succ_pos
@[simp high] theorem natCast_nonpos_iff {n : Nat} : (n : Int) ≤ 0 ↔ n = 0 := by omega
theorem natCast_nonneg (n : Nat) : 0 ≤ (n : Int) := ofNat_le.2 (Nat.zero_le _)
@[simp] theorem sign_natCast_add_one (n : Nat) : sign (n + 1) = 1 := rfl
@[simp, norm_cast] theorem cast_id {n : Int} : Int.cast n = n := rfl

View file

@ -77,8 +77,14 @@ theorem lt.dest {a b : Int} (h : a < b) : ∃ n : Nat, a + Nat.succ n = b :=
@[simp, norm_cast] theorem ofNat_lt {n m : Nat} : (↑n : Int) < ↑m ↔ n < m := by
rw [lt_iff_add_one_le, ← natCast_succ, ofNat_le]; rfl
@[simp, norm_cast] theorem ofNat_pos {n : Nat} : 0 < (↑n : Int) ↔ 0 < n := ofNat_lt
@[simp, norm_cast] theorem natCast_pos {n : Nat} : (0 : Int) < n ↔ 0 < n := ofNat_lt
@[deprecated natCast_pos (since := "2025-05-13"), simp high]
theorem ofNat_pos {n : Nat} : 0 < (↑n : Int) ↔ 0 < n := ofNat_lt
theorem natCast_nonneg (n : Nat) : 0 ≤ (n : Int) := ⟨_⟩
@[deprecated natCast_nonneg (since := "2025-05-13")]
theorem ofNat_nonneg (n : Nat) : 0 ≤ (n : Int) := ⟨_⟩
theorem ofNat_succ_pos (n : Nat) : 0 < (succ n : Int) := ofNat_lt.2 <| Nat.succ_pos _
@ -475,7 +481,7 @@ instance : Std.IdempotentOp (α := Int) max := ⟨Int.max_self⟩
protected theorem mul_nonneg {a b : Int} (ha : 0 ≤ a) (hb : 0 ≤ b) : 0 ≤ a * b := by
let ⟨n, hn⟩ := eq_ofNat_of_zero_le ha
let ⟨m, hm⟩ := eq_ofNat_of_zero_le hb
rw [hn, hm, ← natCast_mul]; apply ofNat_nonneg
rw [hn, hm, ← natCast_mul]; apply natCast_nonneg
protected theorem mul_pos {a b : Int} (ha : 0 < a) (hb : 0 < b) : 0 < a * b := by
let ⟨n, hn⟩ := eq_succ_of_zero_lt ha
@ -1253,7 +1259,7 @@ theorem neg_of_sign_eq_neg_one : ∀ {a : Int}, sign a = -1 → a < 0
| 0 => rfl
| .ofNat (_ + 1) =>
simp +decide only [sign, true_iff]
exact Int.le_add_one (ofNat_nonneg _)
exact Int.le_add_one (natCast_nonneg _)
| .negSucc _ => simp +decide [sign]
@[deprecated sign_nonneg_iff (since := "2025-03-11")] abbrev sign_nonneg := @sign_nonneg_iff

View file

@ -29,6 +29,11 @@ protected theorem pow_nonneg {n : Int} {m : Nat} : 0 ≤ n → 0 ≤ n ^ m := by
| zero => simp
| succ m ih => exact fun h => Int.mul_nonneg (ih h) h
protected theorem pow_ne_zero {n : Int} {m : Nat} : n ≠ 0 → n ^ m ≠ 0 := by
induction m with
| zero => simp
| succ m ih => exact fun h => Int.mul_ne_zero (ih h) h
@[deprecated Nat.pow_le_pow_left (since := "2025-02-17")]
abbrev pow_le_pow_of_le_left := @Nat.pow_le_pow_left

View file

@ -363,8 +363,6 @@ theorem drop_take : ∀ {i j : Nat} {l : List α}, drop i (take j l) = take (j -
| _, _, [] => by simp
| i+1, j+1, h :: t => by
simp [take_succ_cons, drop_succ_cons, drop_take]
congr 1
omega
@[simp] theorem drop_take_self : drop i (take i l) = [] := by
rw [drop_take]

View file

@ -575,15 +575,15 @@ expression `(a >>> b).toUInt8` is not a function of `a.toUInt8` and `b.toUInt8`.
BitVec.toNat_umod, toNat_toBitVec, toNat_ofNat', BitVec.toNat_ofNat, Nat.mod_two_pow_self]
rw [Nat.mod_mod_of_dvd _ (by cases System.Platform.numBits_eq <;> simp_all)]
@[simp] theorem UInt8.ofFin_shiftLeft (a b : Fin UInt8.size) (hb : b < 8) : UInt8.ofFin (a <<< b) = UInt8.ofFin a <<< UInt8.ofFin b :=
@[simp] theorem UInt8.ofFin_shiftLeft (a b : Fin UInt8.size) (hb : b.val < 8) : UInt8.ofFin (a <<< b) = UInt8.ofFin a <<< UInt8.ofFin b :=
UInt8.toFin_inj.1 (by simp [UInt8.toFin_shiftLeft (ofFin a) (ofFin b) hb])
@[simp] theorem UInt16.ofFin_shiftLeft (a b : Fin UInt16.size) (hb : b < 16) : UInt16.ofFin (a <<< b) = UInt16.ofFin a <<< UInt16.ofFin b :=
@[simp] theorem UInt16.ofFin_shiftLeft (a b : Fin UInt16.size) (hb : b.val < 16) : UInt16.ofFin (a <<< b) = UInt16.ofFin a <<< UInt16.ofFin b :=
UInt16.toFin_inj.1 (by simp [UInt16.toFin_shiftLeft (ofFin a) (ofFin b) hb])
@[simp] theorem UInt32.ofFin_shiftLeft (a b : Fin UInt32.size) (hb : b < 32) : UInt32.ofFin (a <<< b) = UInt32.ofFin a <<< UInt32.ofFin b :=
@[simp] theorem UInt32.ofFin_shiftLeft (a b : Fin UInt32.size) (hb : b.val < 32) : UInt32.ofFin (a <<< b) = UInt32.ofFin a <<< UInt32.ofFin b :=
UInt32.toFin_inj.1 (by simp [UInt32.toFin_shiftLeft (ofFin a) (ofFin b) hb])
@[simp] theorem UInt64.ofFin_shiftLeft (a b : Fin UInt64.size) (hb : b < 64) : UInt64.ofFin (a <<< b) = UInt64.ofFin a <<< UInt64.ofFin b :=
@[simp] theorem UInt64.ofFin_shiftLeft (a b : Fin UInt64.size) (hb : b.val < 64) : UInt64.ofFin (a <<< b) = UInt64.ofFin a <<< UInt64.ofFin b :=
UInt64.toFin_inj.1 (by simp [UInt64.toFin_shiftLeft (ofFin a) (ofFin b) hb])
@[simp] theorem USize.ofFin_shiftLeft (a b : Fin USize.size) (hb : b < System.Platform.numBits) : USize.ofFin (a <<< b) = USize.ofFin a <<< USize.ofFin b :=
@[simp] theorem USize.ofFin_shiftLeft (a b : Fin USize.size) (hb : b.val < System.Platform.numBits) : USize.ofFin (a <<< b) = USize.ofFin a <<< USize.ofFin b :=
USize.toFin_inj.1 (by simp [USize.toFin_shiftLeft (ofFin a) (ofFin b) hb])
@[simp] theorem UInt8.ofFin_shiftLeft_mod (a b : Fin UInt8.size) : UInt8.ofFin (a <<< (b % 8)) = UInt8.ofFin a <<< UInt8.ofFin b :=
@ -670,15 +670,15 @@ expression `(a >>> b).toUInt8` is not a function of `a.toUInt8` and `b.toUInt8`.
BitVec.toNat_umod, toNat_toBitVec, toNat_ofNat', BitVec.toNat_ofNat, Nat.mod_two_pow_self]
rw [Nat.mod_mod_of_dvd _ (by cases System.Platform.numBits_eq <;> simp_all)]
@[simp] theorem UInt8.ofFin_shiftRight (a b : Fin UInt8.size) (hb : b < 8) : UInt8.ofFin (a >>> b) = UInt8.ofFin a >>> UInt8.ofFin b :=
@[simp] theorem UInt8.ofFin_shiftRight (a b : Fin UInt8.size) (hb : b.val < 8) : UInt8.ofFin (a >>> b) = UInt8.ofFin a >>> UInt8.ofFin b :=
UInt8.toFin_inj.1 (by simp [UInt8.toFin_shiftRight (ofFin a) (ofFin b) hb])
@[simp] theorem UInt16.ofFin_shiftRight (a b : Fin UInt16.size) (hb : b < 16) : UInt16.ofFin (a >>> b) = UInt16.ofFin a >>> UInt16.ofFin b :=
@[simp] theorem UInt16.ofFin_shiftRight (a b : Fin UInt16.size) (hb : b.val < 16) : UInt16.ofFin (a >>> b) = UInt16.ofFin a >>> UInt16.ofFin b :=
UInt16.toFin_inj.1 (by simp [UInt16.toFin_shiftRight (ofFin a) (ofFin b) hb])
@[simp] theorem UInt32.ofFin_shiftRight (a b : Fin UInt32.size) (hb : b < 32) : UInt32.ofFin (a >>> b) = UInt32.ofFin a >>> UInt32.ofFin b :=
@[simp] theorem UInt32.ofFin_shiftRight (a b : Fin UInt32.size) (hb : b.val < 32) : UInt32.ofFin (a >>> b) = UInt32.ofFin a >>> UInt32.ofFin b :=
UInt32.toFin_inj.1 (by simp [UInt32.toFin_shiftRight (ofFin a) (ofFin b) hb])
@[simp] theorem UInt64.ofFin_shiftRight (a b : Fin UInt64.size) (hb : b < 64) : UInt64.ofFin (a >>> b) = UInt64.ofFin a >>> UInt64.ofFin b :=
@[simp] theorem UInt64.ofFin_shiftRight (a b : Fin UInt64.size) (hb : b.val < 64) : UInt64.ofFin (a >>> b) = UInt64.ofFin a >>> UInt64.ofFin b :=
UInt64.toFin_inj.1 (by simp [UInt64.toFin_shiftRight (ofFin a) (ofFin b) hb])
@[simp] theorem USize.ofFin_shiftRight (a b : Fin USize.size) (hb : b < System.Platform.numBits) : USize.ofFin (a >>> b) = USize.ofFin a >>> USize.ofFin b :=
@[simp] theorem USize.ofFin_shiftRight (a b : Fin USize.size) (hb : b.val < System.Platform.numBits) : USize.ofFin (a >>> b) = USize.ofFin a >>> USize.ofFin b :=
USize.toFin_inj.1 (by simp [USize.toFin_shiftRight (ofFin a) (ofFin b) hb])
@[simp] theorem UInt8.ofFin_shiftRight_mod (a b : Fin UInt8.size) : UInt8.ofFin (a >>> (b % 8)) = UInt8.ofFin a >>> UInt8.ofFin b :=

View file

@ -27,3 +27,23 @@ instance (priority := 300) One.toOfNat1 {α} [One α] : OfNat α (nat_lit 1) whe
instance (priority := 200) One.ofOfNat1 {α} [OfNat α (nat_lit 1)] : One α where
one := 1
/--
The fundamental power operation in a monoid.
`npowRec n a = a*a*...*a` n times.
This function should not be used directly; it is often used to implement a `Pow M Nat` instance,
but end users should use the `a ^ n` notation instead.
-/
def npowRec [One M] [Mul M] : Nat → M → M
| 0, _ => 1
| n + 1, a => npowRec n a * a
/--
The fundamental scalar multiplication in an additive monoid.
`nsmulRec n a = a+a+...+a` n times.
This function should not be used directly;
it is often used to implement an instance for scalar multiplication.
-/
def nsmulRec [Zero M] [Add M] : Nat → M → M
| 0, _ => 0
| n + 1, a => nsmulRec n a + a

View file

@ -10,5 +10,6 @@ import Init.Grind.CommRing.Basic
import Init.Grind.CommRing.Int
import Init.Grind.CommRing.UInt
import Init.Grind.CommRing.SInt
import Init.Grind.CommRing.Fin
import Init.Grind.CommRing.BitVec
import Init.Grind.CommRing.Poly

View file

@ -0,0 +1,110 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
import Init.Grind.CommRing.Basic
import Init.Data.Fin.Lemmas
namespace Lean.Grind
namespace Fin
instance (n : Nat) [NeZero n] : NatCast (Fin n) where
natCast a := Fin.ofNat' n a
def intCast [NeZero n] (a : Int) : Fin n :=
if 0 ≤ a then
Fin.ofNat' n a.natAbs
else
- Fin.ofNat' n a.natAbs
instance (n : Nat) [NeZero n] : IntCast (Fin n) where
intCast := Fin.intCast
theorem intCast_def {n : Nat} [NeZero n] (x : Int) :
(x : Fin n) = if 0 ≤ x then Fin.ofNat' n x.natAbs else -Fin.ofNat' n x.natAbs := rfl
-- TODO: we should replace this at runtime with either repeated squaring,
-- or a GMP accelerated function.
def npow [NeZero n] (x : Fin n) (y : Nat) : Fin n := npowRec y x
instance [NeZero n] : HPow (Fin n) Nat (Fin n) where
hPow := Fin.npow
@[simp] theorem pow_zero [NeZero n] (a : Fin n) : a ^ 0 = 1 := rfl
@[simp] theorem pow_succ [NeZero n] (a : Fin n) (n : Nat) : a ^ (n+1) = a ^ n * a := rfl
theorem add_assoc (a b c : Fin n) : a + b + c = a + (b + c) := by
cases a; cases b; cases c; simp [Fin.add_def, Nat.add_assoc]
theorem add_comm (a b : Fin n) : a + b = b + a := by
cases a; cases b; simp [Fin.add_def, Nat.add_comm]
theorem add_zero [NeZero n] (a : Fin n) : a + 0 = a := by
cases a; simp [Fin.add_def]
next h => rw [Nat.mod_eq_of_lt h]
theorem neg_add_cancel [NeZero n] (a : Fin n) : -a + a = 0 := by
cases a; simp [Fin.add_def, Fin.neg_def, Fin.sub_def]
next h => rw [Nat.sub_add_cancel (Nat.le_of_lt h), Nat.mod_self]
theorem mul_assoc (a b c : Fin n) : a * b * c = a * (b * c) := by
cases a; cases b; cases c; simp [Fin.mul_def, Nat.mul_assoc]
theorem mul_comm (a b : Fin n) : a * b = b * a := by
cases a; cases b; simp [Fin.mul_def, Nat.mul_comm]
theorem zero_mul [NeZero n] (a : Fin n) : 0 * a = 0 := by
cases a; simp [Fin.mul_def]
theorem mul_one [NeZero n] (a : Fin n) : a * 1 = a := by
cases a; simp [Fin.mul_def, OfNat.ofNat]
next h => rw [Nat.mod_eq_of_lt h]
theorem left_distrib (a b c : Fin n) : a * (b + c) = a * b + a * c := by
cases a; cases b; cases c; simp [Fin.mul_def, Fin.add_def, Nat.left_distrib]
theorem ofNat_succ [NeZero n] (a : Nat) : OfNat.ofNat (α := Fin n) (a+1) = OfNat.ofNat a + 1 := by
simp [OfNat.ofNat, Fin.add_def, Fin.ofNat']
theorem sub_eq_add_neg [NeZero n] (a b : Fin n) : a - b = a + -b := by
cases a; cases b; simp [Fin.neg_def, Fin.sub_def, Fin.add_def, Nat.add_comm]
private theorem neg_neg [NeZero n] (a : Fin n) : - - a = a := by
cases a; simp [Fin.neg_def, Fin.sub_def];
next a h => cases a; simp; next a =>
rw [Nat.self_sub_mod n (a+1)]
have : NeZero (n - (a + 1)) := ⟨by omega⟩
rw [Nat.self_sub_mod, Nat.sub_sub_eq_min, Nat.min_eq_right (Nat.le_of_lt h)]
theorem intCast_neg [NeZero n] (i : Int) : Int.cast (R := Fin n) (-i) = - Int.cast (R := Fin n) i := by
simp [Int.cast, IntCast.intCast, Fin.intCast]; split <;> split <;> try omega
next h₁ h₂ => simp [Int.le_antisymm h₁ h₂, Fin.neg_def]
next => simp [Fin.neg_neg]
instance (n : Nat) [NeZero n] : CommRing (Fin n) where
add_assoc := Fin.add_assoc
add_comm := Fin.add_comm
add_zero := Fin.add_zero
neg_add_cancel := Fin.neg_add_cancel
mul_assoc := Fin.mul_assoc
mul_comm := Fin.mul_comm
mul_one := Fin.mul_one
left_distrib := Fin.left_distrib
zero_mul := Fin.zero_mul
pow_zero _ := rfl
pow_succ _ _ := rfl
ofNat_succ := Fin.ofNat_succ
sub_eq_add_neg := Fin.sub_eq_add_neg
intCast_neg := Fin.intCast_neg
instance (n : Nat) [NeZero n] : IsCharP (Fin n) n where
ofNat_eq_zero_iff x := by simp only [OfNat.ofNat, Fin.ofNat']; simp
end Fin
end Lean.Grind

View file

@ -121,7 +121,7 @@ theorem ofNat_natAbs (a : Int) : (a.natAbs : Int) = if 0 ≤ a then a else -a :=
rw [Int.natAbs.eq_def]
split <;> rename_i n
· simp only [Int.ofNat_eq_coe]
rw [if_pos (Int.ofNat_nonneg n)]
rw [if_pos (Int.natCast_nonneg n)]
· simp; rfl
theorem natAbs_dichotomy {a : Int} : 0 ≤ a ∧ a.natAbs = a a < 0 ∧ a.natAbs = -a := by

View file

@ -796,24 +796,6 @@ class AddZeroClass (M : Type u) extends Zero M, Add M where
protected zero_add : ∀ a : M, 0 + a = a
protected add_zero : ∀ a : M, a + 0 = a
section
variable {M : Type u}
/-- The fundamental power operation in a monoid. `npowRec n a = a*a*...*a` n times.
Use instead `a ^ n`, which has better definitional behavior. -/
def npowRec [One M] [Mul M] : Nat → M → M
| 0, _ => 1
| n + 1, a => npowRec n a * a
/-- The fundamental scalar multiplication in an additive monoid. `nsmulRec n a = a+a+...+a` n
times. Use instead `n • a`, which has better definitional behavior. -/
def nsmulRec [Zero M] [Add M] : Nat → M → M
| 0, _ => 0
| n + 1, a => nsmulRec n a + a
end
class AddMonoid (M : Type u) extends AddSemigroup M, AddZeroClass M where
protected nsmul : Nat → M → M
protected nsmul_zero : ∀ x, nsmul 0 x = 0 := by intros; rfl

View file

@ -39,7 +39,7 @@ info: fun x =>
this : ∀ (x : Fin 1), ∃ n, ↑x = n
-/
#guard_msgs in
#check fun (x : Fin 1) => show ∃ (n : Nat), x = n from
#check fun (x : Fin 1) => show ∃ (n : Nat), x = n from
match h : x.1 + 1 with
| 0 => Nat.noConfusion h
| n + 1 => ⟨n, Nat.succ.inj h⟩
@ -57,7 +57,7 @@ info: fun h =>
this : ∀ (h : Fin 1), ∃ n, ↑h = n
-/
#guard_msgs in
#check fun (h : Fin 1) => show ∃ (n : Nat), h = n from
#check fun (h : Fin 1) => show ∃ (n : Nat), h = n from
match h : h.1 + 1 with
| 0 => Nat.noConfusion h
| n + 1 => ⟨n, Nat.succ.inj h⟩
@ -76,7 +76,7 @@ info: fun h =>
this : ∀ (h : Fin 1), ∃ n, ↑h = n
-/
#guard_msgs in
#check fun (h : Fin 1) => show ∃ (n : Nat), h = n from
#check fun (h : Fin 1) => show ∃ (n : Nat), h = n from
match h : h.1 + 1 with
| 0 => Nat.noConfusion h
| h + 1 => ⟨_, Nat.succ.inj _

View file

@ -135,3 +135,27 @@ example [CommRing α] (a b c : α) (f : α → Nat)
example [CommRing α] [NoNatZeroDivisors α] (x y z : α) : 3*x = 1 → 3*z = 2 → 2*y = 2 → x + z + 3*y = 4 := by
grind +ring
example (x y : Fin 11) : x^2*y = 1 → x*y^2 = y → y*x = 1 := by
grind +ring
example (x y : Fin 11) : 3*x = 1 → 3*y = 2 → x + y = 1 := by
grind +ring
example (x y z : Fin 13) :
(x + y + z) ^ 2 = x ^ 2 + y ^ 2 + z ^ 2 + 2 * (x * y + y * z + z * x) := by
grind +ring
example (x y : Fin 17) : (x + y) ^ 3 = x ^ 3 + y ^ 3 + 3 * x * y * (x + y) := by
grind +ring
example (x y : Fin 19) : (x - y) * (x ^ 2 + x * y + y ^ 2) = x ^ 3 - y ^ 3 := by
grind +ring
example (x : Fin 19) : (1 + x) ^ 5 = x ^ 5 + 5 * x ^ 4 + 10 * x ^ 3 + 10 * x ^ 2 + 5 * x + 1 := by
grind +ring
example (x : Fin 10) : (1 + x) ^ 5 = x ^ 5 + 5 * x ^ 4 - 5 * x + 1 := by
grind +ring
example (x y : Fin 3) (h : x = y) : ((x + y) ^ 3 : Fin 3) = - x^3 := by grind +ring

View file

@ -319,8 +319,8 @@ fun a b =>
(Expr.mulL 3 ((Expr.var 1).add (Expr.mulL 2 (Expr.var 0))))).sub
(Expr.num 11))
2 (Poly.add 1 1 (Poly.add 2 0 (Poly.num 0))) 3 (Eq.refl true)))
_proof_2✝))
_proof_2✝)
_proof_3✝))
_proof_3✝)
(iff_self (2 a)))
-/
#guard_msgs (info) in

View file

@ -137,18 +137,10 @@ class AddZeroClass (M : Type u) extends Zero M, Add M where
zero_add : ∀ a : M, 0 + a = a
add_zero : ∀ a : M, a + 0 = a
def npowRec [One M] [Mul M] : Nat → M → M
| 0, _ => 1
| n + 1, a => a * npowRec n a
def nsmulRec [Zero M] [Add M] : Nat → M → M
| 0, _ => 0
| n + 1, a => a + nsmulRec n a
class AddMonoid (M : Type u) extends AddSemigroup M, AddZeroClass M where
nsmul : Nat → M → M := nsmulRec
nsmul_zero : ∀ x, nsmul 0 x = 0 := by intros; rfl
nsmul_succ : ∀ (n : Nat) (x), nsmul (n + 1) x = x + nsmul n x := by intros; rfl
nsmul_succ : ∀ (n : Nat) (x), nsmul (n + 1) x = nsmul n x + x := by intros; rfl
attribute [instance 150] AddSemigroup.toAdd
attribute [instance 50] AddZeroClass.toAdd
@ -156,7 +148,7 @@ attribute [instance 50] AddZeroClass.toAdd
class Monoid (M : Type u) extends Semigroup M, MulOneClass M where
npow : Nat → M → M := npowRec
npow_zero : ∀ x, npow 0 x = 1 := by intros; rfl
npow_succ : ∀ (n : Nat) (x), npow (n + 1) x = x * npow n x := by intros; rfl
npow_succ : ∀ (n : Nat) (x), npow (n + 1) x = npow n x * x := by intros; rfl
@[default_instance high] instance Monoid.Pow {M : Type _} [Monoid M] : Pow M Nat :=
⟨fun x n ↦ Monoid.npow n x⟩
@ -204,7 +196,7 @@ class SubNegMonoid (G : Type u) extends AddMonoid G, Neg G, Sub G where
sub_eq_add_neg : ∀ a b : G, a - b = a + -b := by intros; rfl
zsmul : Int → G → G := zsmulRec
zsmul_zero' : ∀ a : G, zsmul 0 a = 0 := by intros; rfl
zsmul_succ' (n : Nat) (a : G) : zsmul (Int.ofNat n.succ) a = a + zsmul (Int.ofNat n) a := by
zsmul_succ' (n : Nat) (a : G) : zsmul (Int.ofNat n.succ) a = zsmul (Int.ofNat n) a + a := by
intros; rfl
zsmul_neg' (n : Nat) (a : G) : zsmul (Int.negSucc n) a = -zsmul n.succ a := by intros; rfl