fix: CommRing multivariate polynomials (#8016)
This PR fixes several issues in the `CommRing` multivariate polynomial library: 1. Replaces the previous array type with the universe polymorphic `RArray`. 2. Properly eliminates cancelled monomials. 3. Sorts monomials in decreasing order. 4. Marks the parameter `p` of the `IsCharP` class as an output parameter. 5. Adds `LawfulBEq` instances for the types `Power`, `Mon`, and `Poly`.
This commit is contained in:
parent
f163758bcf
commit
27a7a0a2bd
4 changed files with 109 additions and 32 deletions
|
|
@ -217,7 +217,7 @@ end CommRing
|
|||
|
||||
open CommRing
|
||||
|
||||
class IsCharP (α : Type u) [CommRing α] (p : Nat) where
|
||||
class IsCharP (α : Type u) [CommRing α] (p : outParam Nat) where
|
||||
ofNat_eq_zero_iff (p) : ∀ (x : Nat), OfNat.ofNat (α := α) x = 0 ↔ x % p = 0
|
||||
|
||||
namespace IsCharP
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
prelude
|
||||
import Init.Data.Nat.Lemmas
|
||||
import Init.Data.Ord
|
||||
import Init.Data.RArray
|
||||
import Init.Grind.CommRing.Basic
|
||||
|
||||
namespace Lean.Grind
|
||||
|
|
@ -23,16 +24,6 @@ inductive Expr where
|
|||
| pow (a : Expr) (k : Nat)
|
||||
deriving Inhabited, BEq
|
||||
|
||||
-- TODO: add support for universes to Lean.RArray
|
||||
inductive RArray (α : Type u) : Type u where
|
||||
| leaf : α → RArray α
|
||||
| branch : Nat → RArray α → RArray α → RArray α
|
||||
|
||||
def RArray.get (a : RArray α) (n : Nat) : α :=
|
||||
match a with
|
||||
| .leaf x => x
|
||||
| .branch p l r => if n < p then l.get n else r.get n
|
||||
|
||||
abbrev Context (α : Type u) := RArray α
|
||||
|
||||
def Var.denote {α} (ctx : Context α) (v : Var) : α :=
|
||||
|
|
@ -52,6 +43,10 @@ structure Power where
|
|||
k : Nat
|
||||
deriving BEq, Repr
|
||||
|
||||
instance : LawfulBEq Power where
|
||||
eq_of_beq {a} := by cases a <;> intro b <;> cases b <;> simp_all! [BEq.beq]
|
||||
rfl := by intro a; cases a <;> simp! [BEq.beq]
|
||||
|
||||
def Power.varLt (p₁ p₂ : Power) : Bool :=
|
||||
p₁.x.blt p₂.x
|
||||
|
||||
|
|
@ -67,6 +62,18 @@ inductive Mon where
|
|||
| cons (p : Power) (m : Mon)
|
||||
deriving BEq, Repr
|
||||
|
||||
instance : LawfulBEq Mon where
|
||||
eq_of_beq {a} := by
|
||||
induction a <;> intro b <;> cases b <;> simp_all! [BEq.beq]
|
||||
next p₁ p₂ => cases p₁ <;> cases p₂ <;> simp <;> intros <;> simp [*]
|
||||
next p₁ m₁ p₂ m₂ ih =>
|
||||
cases p₁ <;> cases p₂ <;> simp <;> intros <;> simp [*]
|
||||
next h => exact ih h
|
||||
rfl := by
|
||||
intro a
|
||||
induction a <;> simp! [BEq.beq]
|
||||
assumption
|
||||
|
||||
def Mon.denote {α} [CommRing α] (ctx : Context α) : Mon → α
|
||||
| .leaf p => p.denote ctx
|
||||
| .cons p m => p.denote ctx * denote ctx m
|
||||
|
|
@ -205,6 +212,20 @@ inductive Poly where
|
|||
| add (k : Int) (v : Mon) (p : Poly)
|
||||
deriving BEq
|
||||
|
||||
instance : LawfulBEq Poly where
|
||||
eq_of_beq {a} := by
|
||||
induction a <;> intro b <;> cases b <;> simp_all! [BEq.beq]
|
||||
intro h₁ h₂ h₃
|
||||
next m₁ p₁ _ m₂ p₂ ih =>
|
||||
replace h₂ : m₁ == m₂ := h₂
|
||||
simp [ih h₃, eq_of_beq h₂]
|
||||
rfl := by
|
||||
intro a
|
||||
induction a <;> simp! [BEq.beq]
|
||||
next k m p ih =>
|
||||
show m == m ∧ p == p
|
||||
simp [ih]
|
||||
|
||||
def Poly.denote [CommRing α] (ctx : Context α) (p : Poly) : α :=
|
||||
match p with
|
||||
| .num k => Int.cast k
|
||||
|
|
@ -216,10 +237,20 @@ def Poly.ofMon (m : Mon) : Poly :=
|
|||
def Poly.ofVar (x : Var) : Poly :=
|
||||
ofMon (Mon.ofVar x)
|
||||
|
||||
def Poly.isSorted : Poly → Bool
|
||||
| .num _ => true
|
||||
| .add _ _ (.num _) => true
|
||||
| .add _ m₁ (.add k m₂ p) => m₁.grevlex m₂ == .gt && (Poly.add k m₂ p).isSorted
|
||||
|
||||
def Poly.addConst (p : Poly) (k : Int) : Poly :=
|
||||
match p with
|
||||
bif k == 0 then
|
||||
p
|
||||
else
|
||||
go p
|
||||
where
|
||||
go : Poly → Poly
|
||||
| .num k' => .num (k' + k)
|
||||
| .add k' m p => .add k' m (addConst p k)
|
||||
| .add k' m p => .add k' m (go p)
|
||||
|
||||
def Poly.insert (k : Int) (m : Mon) (p : Poly) : Poly :=
|
||||
bif k == 0 then
|
||||
|
|
@ -232,13 +263,13 @@ where
|
|||
| .add k' m' p =>
|
||||
match m.grevlex m' with
|
||||
| .eq =>
|
||||
let k'' := k + k'
|
||||
bif k'' == 0 then
|
||||
let k := k + k'
|
||||
bif k == 0 then
|
||||
p
|
||||
else
|
||||
.add k'' m p
|
||||
| .lt => .add k m (.add k' m' p)
|
||||
| .gt => .add k' m' (go p)
|
||||
.add k m p
|
||||
| .gt => .add k m (.add k' m' p)
|
||||
| .lt => .add k' m' (go p)
|
||||
|
||||
def Poly.concat (p₁ p₂ : Poly) : Poly :=
|
||||
match p₁ with
|
||||
|
|
@ -264,7 +295,11 @@ def Poly.mulMon (k : Int) (m : Mon) (p : Poly) : Poly :=
|
|||
go p
|
||||
where
|
||||
go : Poly → Poly
|
||||
| .num k' => .add (k*k') m (.num 0)
|
||||
| .num k' =>
|
||||
bif k' == 0 then
|
||||
.num 0
|
||||
else
|
||||
.add (k*k') m (.num 0)
|
||||
| .add k' m' p => .add (k*k') (m.mul m') (go p)
|
||||
|
||||
def Poly.combine (p₁ p₂ : Poly) : Poly :=
|
||||
|
|
@ -285,8 +320,8 @@ where
|
|||
go fuel p₁ p₂
|
||||
else
|
||||
.add k m₁ (go fuel p₁ p₂)
|
||||
| .lt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
|
||||
| .gt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
|
||||
| .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
|
||||
| .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
|
||||
|
||||
def Poly.mul (p₁ : Poly) (p₂ : Poly) : Poly :=
|
||||
go p₁ (.num 0)
|
||||
|
|
@ -344,8 +379,8 @@ where
|
|||
p
|
||||
else
|
||||
.add k'' m p
|
||||
| .lt => .add k m (.add k' m' p)
|
||||
| .gt => .add k' m' (go k p)
|
||||
| .gt => .add k m (.add k' m' p)
|
||||
| .lt => .add k' m' (go k p)
|
||||
|
||||
def Poly.mulConstC (k : Int) (p : Poly) (c : Nat) : Poly :=
|
||||
let k := k % c
|
||||
|
|
@ -404,8 +439,8 @@ where
|
|||
go fuel p₁ p₂
|
||||
else
|
||||
.add k m₁ (go fuel p₁ p₂)
|
||||
| .lt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
|
||||
| .gt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
|
||||
| .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
|
||||
| .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
|
||||
|
||||
def Poly.mulC (p₁ : Poly) (p₂ : Poly) (c : Nat) : Poly :=
|
||||
go p₁ (.num 0)
|
||||
|
|
@ -556,9 +591,12 @@ theorem Poly.denote_ofVar {α} [CommRing α] (ctx : Context α) (x : Var)
|
|||
simp [ofVar, denote_ofMon, Mon.denote_ofVar]
|
||||
|
||||
theorem Poly.denote_addConst {α} [CommRing α] (ctx : Context α) (p : Poly) (k : Int) : (addConst p k).denote ctx = p.denote ctx + k := by
|
||||
fun_induction addConst <;> simp [addConst, denote, *]
|
||||
next => rw [intCast_add]
|
||||
next => simp [add_comm, add_left_comm, add_assoc]
|
||||
simp [addConst, cond_eq_if]; split
|
||||
next => simp [*, intCast_zero, add_zero]
|
||||
next =>
|
||||
fun_induction addConst.go <;> simp [addConst.go, denote, *]
|
||||
next => rw [intCast_add]
|
||||
next => simp [add_comm, add_left_comm, add_assoc]
|
||||
|
||||
theorem Poly.denote_insert {α} [CommRing α] (ctx : Context α) (k : Int) (m : Mon) (p : Poly)
|
||||
: (insert k m p).denote ctx = k * m.denote ctx + p.denote ctx := by
|
||||
|
|
@ -595,6 +633,7 @@ theorem Poly.denote_mulMon {α} [CommRing α] (ctx : Context α) (k : Int) (m :
|
|||
next => simp [denote, *, intCast_zero, zero_mul]
|
||||
next =>
|
||||
fun_induction mulMon.go <;> simp [mulMon.go, denote, *]
|
||||
next h => simp +zetaDelta at h; simp [*, intCast_zero, mul_zero]
|
||||
next => simp [intCast_mul, intCast_zero, add_zero, mul_comm, mul_left_comm, mul_assoc]
|
||||
next => simp [Mon.denote_mul, intCast_mul, left_distrib, mul_comm, mul_left_comm, mul_assoc]
|
||||
|
||||
|
|
@ -635,6 +674,11 @@ theorem Expr.denote_toPoly {α} [CommRing α] (ctx : Context α) (e : Expr)
|
|||
next => rw [intCast_pow]
|
||||
next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq]
|
||||
|
||||
theorem Expr.eq_of_toPoly_eq {α} [CommRing α] (ctx : Context α) (a b : Expr) (h : a.toPoly == b.toPoly) : a.denote ctx = b.denote ctx := by
|
||||
have h := congrArg (Poly.denote ctx) (eq_of_beq h)
|
||||
simp [denote_toPoly] at h
|
||||
assumption
|
||||
|
||||
theorem Poly.denote_addConstC {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (p : Poly) (k : Int) : (addConstC p k c).denote ctx = p.denote ctx + k := by
|
||||
fun_induction addConstC <;> simp [addConstC, denote, *]
|
||||
next => rw [IsCharP.intCast_emod, intCast_add]
|
||||
|
|
@ -747,5 +791,11 @@ theorem Expr.denote_toPolyC {α c} [CommRing α] [IsCharP α c] (ctx : Context
|
|||
next => rw [IsCharP.intCast_emod, intCast_pow]
|
||||
next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq]
|
||||
|
||||
theorem Expr.eq_of_toPolyC_eq {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (a b : Expr)
|
||||
(h : a.toPolyC c == b.toPolyC c) : a.denote ctx = b.denote ctx := by
|
||||
have h := congrArg (Poly.denote ctx) (eq_of_beq h)
|
||||
simp [denote_toPolyC] at h
|
||||
assumption
|
||||
|
||||
end CommRing
|
||||
end Lean.Grind
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ instance : CommRing UInt8 where
|
|||
pow_succ := UInt8.pow_succ
|
||||
ofNat_succ x := UInt8.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt8 (2 ^ 8) where
|
||||
instance : IsCharP UInt8 256 where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt8.ofNat x := rfl
|
||||
simp [this, UInt8.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
|
@ -91,7 +91,7 @@ instance : CommRing UInt16 where
|
|||
pow_succ := UInt16.pow_succ
|
||||
ofNat_succ x := UInt16.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt16 (2 ^ 16) where
|
||||
instance : IsCharP UInt16 65536 where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt16.ofNat x := rfl
|
||||
simp [this, UInt16.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
|
@ -111,7 +111,7 @@ instance : CommRing UInt32 where
|
|||
pow_succ := UInt32.pow_succ
|
||||
ofNat_succ x := UInt32.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt32 (2 ^ 32) where
|
||||
instance : IsCharP UInt32 4294967296 where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt32.ofNat x := rfl
|
||||
simp [this, UInt32.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
|
@ -131,7 +131,7 @@ instance : CommRing UInt64 where
|
|||
pow_succ := UInt64.pow_succ
|
||||
ofNat_succ x := UInt64.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt64 (2 ^ 64) where
|
||||
instance : IsCharP UInt64 18446744073709551616 where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt64.ofNat x := rfl
|
||||
simp [this, UInt64.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
|
|
|||
27
tests/lean/run/grind_som1.lean
Normal file
27
tests/lean/run/grind_som1.lean
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import Lean
|
||||
import Init.Grind.CommRing.SOM
|
||||
|
||||
open Lean.Grind
|
||||
open Lean.Grind.CommRing
|
||||
|
||||
-- Convenient RArray literals
|
||||
elab tk:"#R[" ts:term,* "]" : term => do
|
||||
let ts : Array Lean.Syntax := ts
|
||||
let es ← ts.mapM fun stx => Lean.Elab.Term.elabTerm stx none
|
||||
if h : 0 < es.size then
|
||||
Lean.RArray.toExpr (← Lean.Meta.inferType es[0]!) id (Lean.RArray.ofArray es h)
|
||||
else
|
||||
throwErrorAt tk "RArray cannot be empty"
|
||||
|
||||
example (x y : Int) : (x + y) * (x + y + 1) = x * (1 + y + x) + (y + 1 + x) * y :=
|
||||
let ctx := #R[x, y]
|
||||
let lhs : Expr := .mul (.add (.var 0) (.var 1)) (.add (.add (.var 0) (.var 1)) (.num 1))
|
||||
let rhs : Expr := .add (.mul (.var 0) (.add (.add (.num 1) (.var 1)) (.var 0)))
|
||||
(.mul (.add (.add (.var 1) (.num 1)) (.var 0)) (.var 1))
|
||||
Expr.eq_of_toPoly_eq ctx lhs rhs (Eq.refl true)
|
||||
|
||||
example (x y : UInt8) : (128 * x + y) * 2 = y + y :=
|
||||
let ctx := #R[x, y]
|
||||
let lhs : Expr := .mul (.add (.mul (.num 128) (.var 0)) (.var 1)) (.num 2)
|
||||
let rhs : Expr := .add (.var 1) (.var 1)
|
||||
Expr.eq_of_toPolyC_eq ctx lhs rhs (Eq.refl true)
|
||||
Loading…
Add table
Reference in a new issue