From eaf46dfab1d41f45bb4d923e55686cc73f308610 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 16 Apr 2025 18:48:03 -0700 Subject: [PATCH] feat: add `Expr.toPoly` (#7992) This PR add a function for converting `CommRing` expressions into multivariate polynomials. Co-authored-by: Leonardo de Moura --- src/Init/Grind/CommRing/Basic.lean | 6 ++++ src/Init/Grind/CommRing/SOM.lean | 56 ++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/src/Init/Grind/CommRing/Basic.lean b/src/Init/Grind/CommRing/Basic.lean index bb8322717a..0e23e542bc 100644 --- a/src/Init/Grind/CommRing/Basic.lean +++ b/src/Init/Grind/CommRing/Basic.lean @@ -6,6 +6,7 @@ Authors: Kim Morrison prelude import Init.Data.Zero import Init.Data.Int.DivMod.Lemmas +import Init.Data.Int.Pow import Init.TacticsExtra /-! @@ -202,6 +203,11 @@ theorem intCast_mul (x y : Int) : ((x * y : Int) : α) = ((x : α) * (y : α)) : rw [Int.neg_mul_neg, intCast_neg, intCast_neg, neg_mul, mul_neg, neg_neg, intCast_nat_mul, intCast_ofNat, intCast_ofNat] +theorem intCast_pow (x : Int) (k : Nat) : ((x ^ k : Int) : α) = (x : α) ^ k := by + induction k + next => simp [pow_zero, Int.pow_zero, intCast_one] + next k ih => simp [pow_succ, Int.pow_succ, intCast_mul, *] + theorem pow_add (a : α) (k₁ k₂ : Nat) : a ^ (k₁ + k₂) = a^k₁ * a^k₂ := by induction k₂ next => simp [pow_zero, mul_one] diff --git a/src/Init/Grind/CommRing/SOM.lean b/src/Init/Grind/CommRing/SOM.lean index 44886272f5..926d8fdcb5 100644 --- a/src/Init/Grind/CommRing/SOM.lean +++ b/src/Init/Grind/CommRing/SOM.lean @@ -79,6 +79,9 @@ where | .leaf p => acc * p.denote ctx | .cons p m => go (acc * p.denote ctx) m +def Mon.ofVar (x : Var) : Mon := + .leaf { x, k := 1 } + def Mon.concat (m₁ m₂ : Mon) : Mon := match m₁ with | .leaf p => .cons p m₂ @@ -207,6 +210,12 @@ def Poly.denote [CommRing α] (ctx : Context α) (p : Poly) : α := | .num k => Int.cast k | .add k m p => Int.cast k * m.denote ctx + denote ctx p +def Poly.ofMon (m : Mon) : Poly := + .add 1 m (.num 0) + +def Poly.ofVar (x : Var) : Poly := + ofMon (Mon.ofVar x) + def Poly.addConst (p : Poly) (k : Int) : Poly := match p with | .num k' => .num (k' + k) @@ -280,6 +289,26 @@ where | .num k => acc.combine (p₂.mulConst k) | .add k m p₁ => go p₁ (acc.combine (p₂.mulMon k m)) +-- TODO: optimize +def Poly.pow (p : Poly) (k : Nat) : Poly := + match k with + | 0 => .num 1 + | 1 => p + | k+1 => p.mul (pow p k) + +def Expr.toPoly : Expr → Poly + | .num n => .num n + | .var x => Poly.ofVar x + | .add a b => a.toPoly.combine b.toPoly + | .mul a b => a.toPoly.mul b.toPoly + | .neg a => a.toPoly.mulConst (-1) + | .sub a b => a.toPoly.combine (b.toPoly.mulConst (-1)) + | .pow a k => + match a with + | .num n => .num (n^k) + | .var x => Poly.ofMon (.leaf {x, k}) + | _ => a.toPoly.pow k + theorem Power.denote_eq [CommRing α] (ctx : Context α) (p : Power) : p.denote ctx = p.x.denote ctx ^ p.k := by cases p <;> simp [Power.denote] <;> split <;> simp [pow_zero, pow_succ, one_mul] @@ -296,6 +325,10 @@ theorem Mon.denote'_eq_denote [CommRing α] (ctx : Context α) (m : Mon) cases m <;> simp [Mon.denote, Mon.denote'] next p m => apply denote'_go_eq_denote +theorem Mon.denote_ofVar [CommRing α] (ctx : Context α) (x : Var) + : denote ctx (ofVar x) = x.denote ctx := by + simp [denote, ofVar, Power.denote_eq, pow_succ, pow_zero, one_mul] + theorem Mon.denote_concat [CommRing α] (ctx : Context α) (m₁ m₂ : Mon) : denote ctx (concat m₁ m₂) = m₁.denote ctx * m₂.denote ctx := by induction m₁ <;> simp [concat, denote, *] @@ -382,6 +415,14 @@ theorem Mon.eq_of_revlex {m₁ m₂ : Mon} : revlex m₁ m₂ = .eq → m₁ = m theorem Mon.eq_of_grevlex {m₁ m₂ : Mon} : grevlex m₁ m₂ = .eq → m₁ = m₂ := by simp [grevlex, then_eq]; intro; apply eq_of_revlex +theorem Poly.denote_ofMon [CommRing α] (ctx : Context α) (m : Mon) + : denote ctx (ofMon m) = m.denote ctx := by + simp [ofMon, denote, intCast_one, intCast_zero, one_mul, add_zero] + +theorem Poly.denote_ofVar [CommRing α] (ctx : Context α) (x : Var) + : denote ctx (ofVar x) = x.denote ctx := by + simp [ofVar, denote_ofMon, Mon.denote_ofVar] + theorem Poly.denote_addConst [CommRing α] (ctx : Context α) (p : Poly) (k : Int) : (addConst p k).denote ctx = p.denote ctx + k := by fun_induction addConst <;> simp [addConst, denote, *] next => rw [intCast_add] @@ -442,6 +483,21 @@ theorem Poly.denote_mul [CommRing α] (ctx : Context α) (p₁ p₂ : Poly) : (mul p₁ p₂).denote ctx = p₁.denote ctx * p₂.denote ctx := by simp [mul, denote_mul_go, denote, intCast_zero, zero_add] +theorem Poly.denote_pow [CommRing α] (ctx : Context α) (p : Poly) (k : Nat) + : (pow p k).denote ctx = p.denote ctx ^ k := by + fun_induction pow <;> simp [pow, denote, intCast_one, pow_zero] + next => simp [pow_succ, pow_zero, one_mul] + next => simp [denote_mul, *, pow_succ, mul_comm] + +theorem Expr.denote_toPoly [CommRing α] (ctx : Context α) (e : Expr) + : e.toPoly.denote ctx = e.denote ctx := by + fun_induction toPoly + <;> simp [toPoly, denote, Poly.denote, Poly.denote_ofVar, Poly.denote_combine, + Poly.denote_mul, Poly.denote_mulConst, Poly.denote_pow, *] + next => rw [intCast_neg, neg_mul, intCast_one, one_mul] + next => rw [intCast_neg, neg_mul, intCast_one, one_mul, sub_eq_add_neg] + next => rw [intCast_pow] + next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq] end CommRing end Lean.Grind