From 27a7a0a2bdb00e5d562275ff9e8f18ff5b8822d2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 17 Apr 2025 21:34:05 -0700 Subject: [PATCH] fix: `CommRing` multivariate polynomials (#8016) This PR fixes several issues in the `CommRing` multivariate polynomial library: 1. Replaces the previous array type with the universe polymorphic `RArray`. 2. Properly eliminates cancelled monomials. 3. Sorts monomials in decreasing order. 4. Marks the parameter `p` of the `IsCharP` class as an output parameter. 5. Adds `LawfulBEq` instances for the types `Power`, `Mon`, and `Poly`. --- src/Init/Grind/CommRing/Basic.lean | 2 +- src/Init/Grind/CommRing/SOM.lean | 104 +++++++++++++++++++++-------- src/Init/Grind/CommRing/UInt.lean | 8 +-- tests/lean/run/grind_som1.lean | 27 ++++++++ 4 files changed, 109 insertions(+), 32 deletions(-) create mode 100644 tests/lean/run/grind_som1.lean diff --git a/src/Init/Grind/CommRing/Basic.lean b/src/Init/Grind/CommRing/Basic.lean index 0e23e542bc..64adff4661 100644 --- a/src/Init/Grind/CommRing/Basic.lean +++ b/src/Init/Grind/CommRing/Basic.lean @@ -217,7 +217,7 @@ end CommRing open CommRing -class IsCharP (α : Type u) [CommRing α] (p : Nat) where +class IsCharP (α : Type u) [CommRing α] (p : outParam Nat) where ofNat_eq_zero_iff (p) : ∀ (x : Nat), OfNat.ofNat (α := α) x = 0 ↔ x % p = 0 namespace IsCharP diff --git a/src/Init/Grind/CommRing/SOM.lean b/src/Init/Grind/CommRing/SOM.lean index 11b08c4ee4..bd82b04fc3 100644 --- a/src/Init/Grind/CommRing/SOM.lean +++ b/src/Init/Grind/CommRing/SOM.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Init.Data.Nat.Lemmas import Init.Data.Ord +import Init.Data.RArray import Init.Grind.CommRing.Basic namespace Lean.Grind @@ -23,16 +24,6 @@ inductive Expr where | pow (a : Expr) (k : Nat) deriving Inhabited, BEq --- TODO: add support for universes to Lean.RArray -inductive RArray (α : Type u) : Type u where - | leaf : α → RArray α - | branch : Nat → RArray α → RArray α → RArray α - -def RArray.get (a : RArray α) (n : Nat) : α := - match a with - | .leaf x => x - | .branch p l r => if n < p then l.get n else r.get n - abbrev Context (α : Type u) := RArray α def Var.denote {α} (ctx : Context α) (v : Var) : α := @@ -52,6 +43,10 @@ structure Power where k : Nat deriving BEq, Repr +instance : LawfulBEq Power where + eq_of_beq {a} := by cases a <;> intro b <;> cases b <;> simp_all! [BEq.beq] + rfl := by intro a; cases a <;> simp! [BEq.beq] + def Power.varLt (p₁ p₂ : Power) : Bool := p₁.x.blt p₂.x @@ -67,6 +62,18 @@ inductive Mon where | cons (p : Power) (m : Mon) deriving BEq, Repr +instance : LawfulBEq Mon where + eq_of_beq {a} := by + induction a <;> intro b <;> cases b <;> simp_all! [BEq.beq] + next p₁ p₂ => cases p₁ <;> cases p₂ <;> simp <;> intros <;> simp [*] + next p₁ m₁ p₂ m₂ ih => + cases p₁ <;> cases p₂ <;> simp <;> intros <;> simp [*] + next h => exact ih h + rfl := by + intro a + induction a <;> simp! [BEq.beq] + assumption + def Mon.denote {α} [CommRing α] (ctx : Context α) : Mon → α | .leaf p => p.denote ctx | .cons p m => p.denote ctx * denote ctx m @@ -205,6 +212,20 @@ inductive Poly where | add (k : Int) (v : Mon) (p : Poly) deriving BEq +instance : LawfulBEq Poly where + eq_of_beq {a} := by + induction a <;> intro b <;> cases b <;> simp_all! [BEq.beq] + intro h₁ h₂ h₃ + next m₁ p₁ _ m₂ p₂ ih => + replace h₂ : m₁ == m₂ := h₂ + simp [ih h₃, eq_of_beq h₂] + rfl := by + intro a + induction a <;> simp! [BEq.beq] + next k m p ih => + show m == m ∧ p == p + simp [ih] + def Poly.denote [CommRing α] (ctx : Context α) (p : Poly) : α := match p with | .num k => Int.cast k @@ -216,10 +237,20 @@ def Poly.ofMon (m : Mon) : Poly := def Poly.ofVar (x : Var) : Poly := ofMon (Mon.ofVar x) +def Poly.isSorted : Poly → Bool + | .num _ => true + | .add _ _ (.num _) => true + | .add _ m₁ (.add k m₂ p) => m₁.grevlex m₂ == .gt && (Poly.add k m₂ p).isSorted + def Poly.addConst (p : Poly) (k : Int) : Poly := - match p with + bif k == 0 then + p + else + go p +where + go : Poly → Poly | .num k' => .num (k' + k) - | .add k' m p => .add k' m (addConst p k) + | .add k' m p => .add k' m (go p) def Poly.insert (k : Int) (m : Mon) (p : Poly) : Poly := bif k == 0 then @@ -232,13 +263,13 @@ where | .add k' m' p => match m.grevlex m' with | .eq => - let k'' := k + k' - bif k'' == 0 then + let k := k + k' + bif k == 0 then p else - .add k'' m p - | .lt => .add k m (.add k' m' p) - | .gt => .add k' m' (go p) + .add k m p + | .gt => .add k m (.add k' m' p) + | .lt => .add k' m' (go p) def Poly.concat (p₁ p₂ : Poly) : Poly := match p₁ with @@ -264,7 +295,11 @@ def Poly.mulMon (k : Int) (m : Mon) (p : Poly) : Poly := go p where go : Poly → Poly - | .num k' => .add (k*k') m (.num 0) + | .num k' => + bif k' == 0 then + .num 0 + else + .add (k*k') m (.num 0) | .add k' m' p => .add (k*k') (m.mul m') (go p) def Poly.combine (p₁ p₂ : Poly) : Poly := @@ -285,8 +320,8 @@ where go fuel p₁ p₂ else .add k m₁ (go fuel p₁ p₂) - | .lt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂)) - | .gt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂) + | .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂)) + | .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂) def Poly.mul (p₁ : Poly) (p₂ : Poly) : Poly := go p₁ (.num 0) @@ -344,8 +379,8 @@ where p else .add k'' m p - | .lt => .add k m (.add k' m' p) - | .gt => .add k' m' (go k p) + | .gt => .add k m (.add k' m' p) + | .lt => .add k' m' (go k p) def Poly.mulConstC (k : Int) (p : Poly) (c : Nat) : Poly := let k := k % c @@ -404,8 +439,8 @@ where go fuel p₁ p₂ else .add k m₁ (go fuel p₁ p₂) - | .lt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂)) - | .gt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂) + | .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂)) + | .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂) def Poly.mulC (p₁ : Poly) (p₂ : Poly) (c : Nat) : Poly := go p₁ (.num 0) @@ -556,9 +591,12 @@ theorem Poly.denote_ofVar {α} [CommRing α] (ctx : Context α) (x : Var) simp [ofVar, denote_ofMon, Mon.denote_ofVar] theorem Poly.denote_addConst {α} [CommRing α] (ctx : Context α) (p : Poly) (k : Int) : (addConst p k).denote ctx = p.denote ctx + k := by - fun_induction addConst <;> simp [addConst, denote, *] - next => rw [intCast_add] - next => simp [add_comm, add_left_comm, add_assoc] + simp [addConst, cond_eq_if]; split + next => simp [*, intCast_zero, add_zero] + next => + fun_induction addConst.go <;> simp [addConst.go, denote, *] + next => rw [intCast_add] + next => simp [add_comm, add_left_comm, add_assoc] theorem Poly.denote_insert {α} [CommRing α] (ctx : Context α) (k : Int) (m : Mon) (p : Poly) : (insert k m p).denote ctx = k * m.denote ctx + p.denote ctx := by @@ -595,6 +633,7 @@ theorem Poly.denote_mulMon {α} [CommRing α] (ctx : Context α) (k : Int) (m : next => simp [denote, *, intCast_zero, zero_mul] next => fun_induction mulMon.go <;> simp [mulMon.go, denote, *] + next h => simp +zetaDelta at h; simp [*, intCast_zero, mul_zero] next => simp [intCast_mul, intCast_zero, add_zero, mul_comm, mul_left_comm, mul_assoc] next => simp [Mon.denote_mul, intCast_mul, left_distrib, mul_comm, mul_left_comm, mul_assoc] @@ -635,6 +674,11 @@ theorem Expr.denote_toPoly {α} [CommRing α] (ctx : Context α) (e : Expr) next => rw [intCast_pow] next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq] +theorem Expr.eq_of_toPoly_eq {α} [CommRing α] (ctx : Context α) (a b : Expr) (h : a.toPoly == b.toPoly) : a.denote ctx = b.denote ctx := by + have h := congrArg (Poly.denote ctx) (eq_of_beq h) + simp [denote_toPoly] at h + assumption + theorem Poly.denote_addConstC {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (p : Poly) (k : Int) : (addConstC p k c).denote ctx = p.denote ctx + k := by fun_induction addConstC <;> simp [addConstC, denote, *] next => rw [IsCharP.intCast_emod, intCast_add] @@ -747,5 +791,11 @@ theorem Expr.denote_toPolyC {α c} [CommRing α] [IsCharP α c] (ctx : Context next => rw [IsCharP.intCast_emod, intCast_pow] next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq] +theorem Expr.eq_of_toPolyC_eq {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (a b : Expr) + (h : a.toPolyC c == b.toPolyC c) : a.denote ctx = b.denote ctx := by + have h := congrArg (Poly.denote ctx) (eq_of_beq h) + simp [denote_toPolyC] at h + assumption + end CommRing end Lean.Grind diff --git a/src/Init/Grind/CommRing/UInt.lean b/src/Init/Grind/CommRing/UInt.lean index e24df8e2ea..5385c65710 100644 --- a/src/Init/Grind/CommRing/UInt.lean +++ b/src/Init/Grind/CommRing/UInt.lean @@ -71,7 +71,7 @@ instance : CommRing UInt8 where pow_succ := UInt8.pow_succ ofNat_succ x := UInt8.ofNat_add x 1 -instance : IsCharP UInt8 (2 ^ 8) where +instance : IsCharP UInt8 256 where ofNat_eq_zero_iff {x} := by have : OfNat.ofNat x = UInt8.ofNat x := rfl simp [this, UInt8.ofNat_eq_iff_mod_eq_toNat] @@ -91,7 +91,7 @@ instance : CommRing UInt16 where pow_succ := UInt16.pow_succ ofNat_succ x := UInt16.ofNat_add x 1 -instance : IsCharP UInt16 (2 ^ 16) where +instance : IsCharP UInt16 65536 where ofNat_eq_zero_iff {x} := by have : OfNat.ofNat x = UInt16.ofNat x := rfl simp [this, UInt16.ofNat_eq_iff_mod_eq_toNat] @@ -111,7 +111,7 @@ instance : CommRing UInt32 where pow_succ := UInt32.pow_succ ofNat_succ x := UInt32.ofNat_add x 1 -instance : IsCharP UInt32 (2 ^ 32) where +instance : IsCharP UInt32 4294967296 where ofNat_eq_zero_iff {x} := by have : OfNat.ofNat x = UInt32.ofNat x := rfl simp [this, UInt32.ofNat_eq_iff_mod_eq_toNat] @@ -131,7 +131,7 @@ instance : CommRing UInt64 where pow_succ := UInt64.pow_succ ofNat_succ x := UInt64.ofNat_add x 1 -instance : IsCharP UInt64 (2 ^ 64) where +instance : IsCharP UInt64 18446744073709551616 where ofNat_eq_zero_iff {x} := by have : OfNat.ofNat x = UInt64.ofNat x := rfl simp [this, UInt64.ofNat_eq_iff_mod_eq_toNat] diff --git a/tests/lean/run/grind_som1.lean b/tests/lean/run/grind_som1.lean new file mode 100644 index 0000000000..d28d532fa9 --- /dev/null +++ b/tests/lean/run/grind_som1.lean @@ -0,0 +1,27 @@ +import Lean +import Init.Grind.CommRing.SOM + +open Lean.Grind +open Lean.Grind.CommRing + +-- Convenient RArray literals +elab tk:"#R[" ts:term,* "]" : term => do + let ts : Array Lean.Syntax := ts + let es ← ts.mapM fun stx => Lean.Elab.Term.elabTerm stx none + if h : 0 < es.size then + Lean.RArray.toExpr (← Lean.Meta.inferType es[0]!) id (Lean.RArray.ofArray es h) + else + throwErrorAt tk "RArray cannot be empty" + +example (x y : Int) : (x + y) * (x + y + 1) = x * (1 + y + x) + (y + 1 + x) * y := + let ctx := #R[x, y] + let lhs : Expr := .mul (.add (.var 0) (.var 1)) (.add (.add (.var 0) (.var 1)) (.num 1)) + let rhs : Expr := .add (.mul (.var 0) (.add (.add (.num 1) (.var 1)) (.var 0))) + (.mul (.add (.add (.var 1) (.num 1)) (.var 0)) (.var 1)) + Expr.eq_of_toPoly_eq ctx lhs rhs (Eq.refl true) + +example (x y : UInt8) : (128 * x + y) * 2 = y + y := + let ctx := #R[x, y] + let lhs : Expr := .mul (.add (.mul (.num 128) (.var 0)) (.var 1)) (.num 2) + let rhs : Expr := .add (.var 1) (.var 1) + Expr.eq_of_toPolyC_eq ctx lhs rhs (Eq.refl true)