feat: NoZeroNatDivisors helper class for grind (#8111)

This PR adds the helper type class `NoZeroNatDivisors` for the
commutative ring procedure in `grind`. Core only implements it for
`Int`. It can be instantiated in Mathlib for any type `A` that
implements `NoZeroSMulDivisors Nat A`.
See `findSimp?` and `PolyDerivation` for details on how this instance
impacts the commutative ring procedure.
This commit is contained in:
Leonardo de Moura 2025-04-26 08:14:27 -07:00 committed by GitHub
parent 18f8a18bfc
commit d81a922a20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 181 additions and 82 deletions

View file

@ -317,4 +317,27 @@ theorem natCast_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
end IsCharP
/--
Special case of Mathlib's `NoZeroSMulDivisors Nat α`.
-/
class NoZeroNatDivisors (α : Type u) [CommRing α] where
no_zero_nat_divisors : ∀ (k : Nat) (a : α), k ≠ 0 → OfNat.ofNat (α := α) k * a = 0 → a = 0
export NoZeroNatDivisors (no_zero_nat_divisors)
theorem no_zero_int_divisors (α : Type u) [CommRing α] [NoZeroNatDivisors α] {k : Int} (a : α)
: k ≠ 0 → k * a = 0 → a = 0 := by
match k with
| (k : Nat) =>
simp [intCast_natCast]
intro h₁ h₂
replace h₁ : k ≠ 0 := by intro h; simp [h] at h₁
exact no_zero_nat_divisors k a h₁ h₂
| -(k+1 : Nat) =>
rw [Int.natCast_add, ← Int.natCast_add, intCast_neg, intCast_natCast]
intro _ h
replace h := congrArg (-·) h; simp at h
rw [← neg_mul, neg_neg, neg_zero] at h
exact no_zero_nat_divisors (k+1) a (Nat.succ_ne_zero _) h
end Lean.Grind

View file

@ -29,4 +29,14 @@ instance : CommRing Int where
instance : IsCharP Int 0 where
ofNat_eq_zero_iff {x} := by erw [Int.ofNat_eq_zero]; simp
instance : NoZeroNatDivisors Int where
no_zero_nat_divisors k a h₁ h₂ := by
cases Int.mul_eq_zero.mp h₂
next h =>
rw [← Int.natCast_zero] at h
have h : (k : Int).toNat = (↑0 : Int).toNat := congrArg Int.toNat h;
simp at h
contradiction
next => assumption
end Lean.Grind

View file

@ -21,6 +21,10 @@ instance [ToString α] : ToString (LOption α) where
| .undef => "undef"
| .some a => "(some " ++ toString a ++ ")"
def LOption.toOption : LOption α → Option α
| .some a => .some a
| _ => .none
end Lean
def Option.toLOption {α : Type u} : Option α → Lean.LOption α

View file

@ -71,4 +71,7 @@ where
def EqCnstr.denoteExpr (c : EqCnstr) : RingM Expr := do
mkEq (← c.p.denoteExpr) (← denoteNum 0)
def PolyDerivation.denoteExpr (d : PolyDerivation) : RingM Expr := do
d.p.denoteExpr
end Lean.Meta.Grind.Arith.CommRing

View file

@ -38,39 +38,68 @@ private def toRingExpr? (e : Expr) : RingM (Option RingExpr) := do
/--
Returns `some c`, where `c` is an equation from the basis whose leading monomial divides `m`.
If `unitOnly` is true, only equations with a unit leading coefficient are considered.
Remark: if the current ring does not satisfy the property
```
∀ (k : Nat) (a : α), k ≠ 0 → OfNat.ofNat (α := α) k * a = 0 → a = 0
```
then the leading coefficient of the equation must also divide `k`
-/
def _root_.Lean.Grind.CommRing.Mon.findSimp? (m : Mon) (unitOnly : Bool := false) : RingM (Option EqCnstr) :=
go m
where
go : Mon → RingM (Option EqCnstr)
def _root_.Lean.Grind.CommRing.Mon.findSimp? (k : Int) (m : Mon) : RingM (Option EqCnstr) := do
let noZeroDiv ← noZeroDivisors
let rec go : Mon → RingM (Option EqCnstr)
| .unit => return none
| .mult pw m' => do
for c in (← getRing).varToBasis[pw.x]! do
if !unitOnly || c.p.lc.natAbs == 1 then
if noZeroDiv || (c.p.lc k) then
if c.p.divides m then
return some c
go m'
go m
/--
Returns `some c`, where `c` is an equation from the basis whose leading monomial divides some
monomial in `p`.
If `unitOnly` is true, only equations with a unit leading coefficient are considered.
-/
def _root_.Lean.Grind.CommRing.Poly.findSimp? (p : Poly) (unitOnly : Bool := false) : RingM (Option EqCnstr) := do
def _root_.Lean.Grind.CommRing.Poly.findSimp? (p : Poly) : RingM (Option EqCnstr) := do
match p with
| .num _ => return none
| .add _ m p =>
match (← m.findSimp? unitOnly) with
| .add k m p =>
match (← m.findSimp? k) with
| some c => return some c
| none => p.findSimp? unitOnly
| none => p.findSimp?
/-- Simplifies `c` using `c'`. -/
def EqCnstr.simplify1 (c c' : EqCnstr) : RingM (Option EqCnstr) := do
let some r := c'.p.simp? c.p (← nonzeroChar?) | return none
let c := { c with
/-- Simplifies `d.p` using `c`, and returns an extended polynomial derivation. -/
def PolyDerivation.simplify1 (d : PolyDerivation) (c : EqCnstr) : RingM (Option PolyDerivation) := do
let some r := d.p.simp? c.p (← nonzeroChar?) | return none
trace_goal[grind.ring.simp] "{← r.p.denoteExpr}"
return some <| .step r.p r.k₁ d r.k₂ r.m₂ c
/-- Simplifies `d.p` using `c` until it is not applicable anymore, and returns an extended polynomial derivation. -/
def PolyDerivation.simplifyWith (d : PolyDerivation) (c : EqCnstr) : RingM PolyDerivation := do
let mut d := d
repeat
checkSystem "ring"
let some r ← d.simplify1 c | return d
trace_goal[grind.debug.ring.simp] "simplifying{indentD (← d.denoteExpr)}\nwith{indentD (← c.denoteExpr)}"
d := r
return d
/-- Simplified `d.p` using the current basis, and returns the extended polynomial derivation. -/
def PolyDerivation.simplify (d : PolyDerivation) : RingM PolyDerivation := do
let mut d := d
repeat
let some c ← d.p.findSimp? |
trace_goal[grind.debug.ring.simp] "simplified{indentD (← d.denoteExpr)}"
return d
d ← d.simplifyWith c
return d
/-- Simplifies `c₁` using `c₂`. -/
def EqCnstr.simplify1 (c₁ c₂ : EqCnstr) : RingM (Option EqCnstr) := do
let some r := c₁.p.simp? c₂.p (← nonzeroChar?) | return none
let c := { c₁ with
p := r.p
h := .simp c' c r.k₁ r.k₂ r.m
h := .simp r.k₁ c₁ r.k₂ r.m₂ c₂
}
trace_goal[grind.ring.simp] "{← c.p.denoteExpr}"
return some c
@ -101,7 +130,7 @@ def EqCnstr.checkConstant (c : EqCnstr) : RingM Bool := do
if k == 0 then
trace_goal[grind.ring.assert.trivial] "{← c.denoteExpr}"
else if (← hasChar) then
setInconsistent c
setUnsatEq c
else
-- Remark: we currently don't do anything if the characteristic is not known.
trace_goal[grind.ring.assert.discard] "{← c.denoteExpr}"

View file

@ -111,10 +111,10 @@ def Poly.spol (p₁ p₂ : Poly) (char? : Option Nat := none) : SPolResult :=
| _, _ => {}
/--
Result of simplifying a polynomial `p₂` using a polynomial `p₁`.
Result of simplifying a polynomial `p₁` using a polynomial `p₂`.
The simplification rewrites the first monomial of `p` that can be divided
by the leading monomial of `p`.
The simplification rewrites the first monomial of `p` that can be divided
by the leading monomial of `p`.
-/
structure SimpResult where
/-- The resulting simplified polynomial after rewriting. -/
@ -123,43 +123,43 @@ structure SimpResult where
k₁ : Int := 0
/-- The integer coefficient multiplied with polynomial `p₂` during rewriting. -/
k₂ : Int := 0
/-- The monomial factor applied to polynomial `p`. -/
m : Mon := .unit
/-- The monomial factor applied to polynomial `p`. -/
m : Mon := .unit
/--
Simplifies polynomial `p₂` using polynomial `p₁` by rewriting.
Simplifies polynomial `p₁` using polynomial `p₂` by rewriting.
This function attempts to rewrite `p` by eliminating the first occurrence of
the leading monomial of `p`.
This function attempts to rewrite `p` by eliminating the first occurrence of
the leading monomial of `p`.
Remark: if `char? = some c`, then `c` is the characteristic of the ring.
-/
def Poly.simp? (p₁ p₂ : Poly) (char? : Option Nat := none) : Option SimpResult :=
match p with
| .add k₁' m₁ p₁ =>
let rec go? (p : Poly) : Option SimpResult :=
match p with
| .add k₂' m₂ p₂ =>
if m₁.divides m₂ then
let m := m₂.div m₁
match p with
| .add k₂' m₂ p₂ =>
let rec go? (p : Poly) : Option SimpResult :=
match p with
| .add k₁' m₁ p₁ =>
if m₂.divides m₁ then
let m₂ := m₁.div m₂
let g := Nat.gcd k₁'.natAbs k₂'.natAbs
let k₁ := -k₂'/g
let k₂ := k₁'/g
let p := (p₁.mulMon' k₁ m char?).combine' (p₂.mulConst' k₂ char?) char?
some { p, k₁, k₂, m }
else if let some r := go? p then
let k₁ := k₂'/g
let k₂ := -k₁'/g
let p := (p₂.mulMon' k₂ m₂ char?).combine' (p₁.mulConst' k₁ char?) char?
some { p, k₁, k₂, m }
else if let some r := go? p then
if let some char := char? then
let k := (k₂'*r.k₂) % char
let k := (k₁'*r.k₁) % char
if k == 0 then
some r
else
some { r with p := .add k m r.p }
some { r with p := .add k m r.p }
else
some { r with p := .add (k₂'*r.k₂) m₂ r.p }
some { r with p := .add (k₁'*r.k₁) m₁ r.p }
else
none
| .num _ => none
go? p
go? p
| _ => none
def Poly.degree : Poly → Nat

View file

@ -126,8 +126,8 @@ partial def EqCnstr.toPreNullCert (c : EqCnstr) : ProofM PreNullCert := caching
let h ← mkEqProof a b
modify fun s => { s with hyps := s.hyps.push { h, lhs, rhs } }
return PreNullCert.unit i (i+1)
| .superpose c₁ c₂ k₁ k₂ m₁ m₂ => (← c₁.toPreNullCert).combine k₁ m₁ k₂ m₂ (← c₂.toPreNullCert)
| .simp c₁ c₂ k₁ k₂ m => (← c₁.toPreNullCert).combine k₁ m k₂ .unit (← c₂.toPreNullCert)
| .superpose k₁ m₁ c₁ k₂ m₂ c₂ => (← c₁.toPreNullCert).combine k₁ m₁ k₂ m₂ (← c₂.toPreNullCert)
| .simp k₁ c₁ k₂ m₂ c₂ => (← c₁.toPreNullCert).combine k₁ .unit k₂ m₂ (← c₂.toPreNullCert)
| .mul k c => (← c.toPreNullCert).mul k
| .div k c => (← c.toPreNullCert).div k
@ -150,7 +150,7 @@ def NullCertExt.check (c : EqCnstr) (nc : NullCertExt) : RingM Bool := do
let p₂ ← nc.toPoly
return p₁ == p₂
def setInconsistent (c : EqCnstr) : RingM Unit := do
def setUnsatEq (c : EqCnstr) : RingM Unit := do
trace_goal[grind.ring.assert.unsat] "{← c.denoteExpr}"
let nc ← c.mkNullCertExt
trace_goal[grind.ring.assert.unsat] "{nc.d}*({← c.p.denoteExpr}), {← (← nc.toPoly).denoteExpr}"
@ -165,10 +165,12 @@ private def mkLemmaPrefix (declName declNameC : Name) : RingM Expr := do
else
return mkApp3 (mkConst declName [ring.u]) ring.type ring.commRingInst ctx
-- TODO: delete
def setNeUnsat (a b : Expr) (ra rb : RingExpr) : RingM Unit := do
let h ← mkLemmaPrefix ``Grind.CommRing.ne_unsat ``Grind.CommRing.ne_unsatC
closeGoal <| mkApp4 h (toExpr ra) (toExpr rb) reflBoolTrue (← mkDiseqProof a b)
-- TODO: delete
def setEqUnsat (k : Int) (a b : Expr) (ra rb : RingExpr) : RingM Unit := do
let mut h ← mkLemmaPrefix ``Grind.CommRing.eq_unsat ``Grind.CommRing.eq_unsatC
let (charInst, c) ← getCharInst

View file

@ -105,6 +105,9 @@ where
| trace_goal[grind.ring] "found instance for{indentExpr charType}\nbut characteristic is not a natural number"; pure none
trace_goal[grind.ring] "characteristic: {n}"
pure <| some (charInst, n)
let noZeroDivType := mkApp2 (mkConst ``Grind.NoZeroNatDivisors [u]) type commRingInst
let noZeroDivInst? := (← trySynthInstance noZeroDivType).toOption
trace_goal[grind.ring] "NoZeroNatDivisors available: {noZeroDivInst?.isSome}"
let addFn ← getAddFn type u commRingInst
let mulFn ← getMulFn type u commRingInst
let subFn ← getSubFn type u commRingInst
@ -113,7 +116,7 @@ where
let intCastFn ← getIntCastFn type u commRingInst
let natCastFn ← getNatCastFn type u commRingInst
let id := (← get').rings.size
let ring : Ring := { id, type, u, commRingInst, charInst?, addFn, mulFn, subFn, negFn, powFn, intCastFn, natCastFn }
let ring : Ring := { id, type, u, commRingInst, charInst?, noZeroDivInst?, addFn, mulFn, subFn, negFn, powFn, intCastFn, natCastFn }
modify' fun s => { s with rings := s.rings.push ring }
return some id

View file

@ -26,8 +26,8 @@ structure EqCnstr where
inductive EqCnstrProof where
| core (a b : Expr) (ra rb : RingExpr)
| superpose (c₁ c₂ : EqCnstr) (k₁ k₂ : Int) (m₁ m₂ : Mon)
| simp (c₁ c₂ : EqCnstr) (k₁ k₂ : Int) (m : Mon)
| superpose (k₁ : Int) (m₁ : Mon) (c₁ : EqCnstr) (k₂ : Int) (m₂ : Mon) (c₂ : EqCnstr)
| simp (k₁ : Int) (c₁ : EqCnstr) (k₂ : Int) (m₂ : Mon) (c₂ : EqCnstr)
| mul (k : Int) (e : EqCnstr)
| div (k : Int) (e : EqCnstr)
@ -47,7 +47,7 @@ protected def EqCnstr.compare (c₁ c₂ : EqCnstr) : Ordering :=
abbrev Queue : Type := RBTree EqCnstr EqCnstr.compare
/--
A simplification chain.
A polynomial equipped with a chain of rewrite steps that justifies its equality to the original input.
From an input polynomial `p`, we use equations (i.e., `EqCnstr`) as rewriting rules.
For example, consider the following sequence of rewrites for the input polynomial `x^2 + x*y`
using the equations `x - 1 = 0` (`c₁`) and `y - 2 = 0` (`c₂`).
@ -72,15 +72,16 @@ for
```
because `x-1 = 0` and `y-2=0`
-/
inductive SimpChain where
inductive PolyDerivation where
| input (p : Poly)
| /--
```
p = k₁*s.getPoly + k₂*m*c.p
p = k₁*d.getPoly + k₂*m₂*c.p
```
The coefficient `k₁` is used because the leading monomial in `c` may not be monic.
Thus, if we follow the chain back to the input polynomial, we have that
`p = C * input_p` for a `C` that is equal to the product of all `k₁`s in the chain.
We have that `C ≠ 1` only if the ring does not implement `NoZeroNatDivisors`.
Here is a small example where we simplify `x+y` using the equations
`2*x - 1 = 0` (`c₁`), `3*y - 1 = 0` (`c₂`), and `6*z + 5 = 0` (`c₃`)
```
@ -102,56 +103,57 @@ inductive SimpChain where
```
0 = 6*(x + y + z)
```
If we have a commutative ring where
Recall that if the ring implement `NoZeroNatDivisors`, then the following property holds:
```
∀ (k : Int) (a : α), k ≠ 0 → (intCast k) * a = 0 → a = 0
```
grind can deduce that `x+y+z = 0`
-/
simp (p : Poly) (c : EqCnstr) (k₁ : Int) (k₂ : Int) (m : Mon) (s : SimpChain)
step (p : Poly) (k₁ : Int) (d : PolyDerivation) (k₂ : Int) (m₂ : Mon) (c : EqCnstr)
def SimpChain.getPoly : SimpChain → Poly
def PolyDerivation.p : PolyDerivation → Poly
| .input p => p
| .simp p .. => p
| .step p .. => p
/-- State for each `CommRing` processed by this module. -/
structure Ring where
id : Nat
type : Expr
id : Nat
type : Expr
/-- Cached `getDecLevel type` -/
u : Level
u : Level
/-- `CommRing` instance for `type` -/
commRingInst : Expr
commRingInst : Expr
/-- `IsCharP` instance for `type` if available. -/
charInst? : Option (Expr × Nat) := .none
addFn : Expr
mulFn : Expr
subFn : Expr
negFn : Expr
powFn : Expr
intCastFn : Expr
natCastFn : Expr
charInst? : Option (Expr × Nat) := .none
/-- `NoZeroNatDivisors` instance for `type` if available. -/
noZeroDivInst? : Option Expr := .none
addFn : Expr
mulFn : Expr
subFn : Expr
negFn : Expr
powFn : Expr
intCastFn : Expr
natCastFn : Expr
/--
Mapping from variables to their denotations.
Remark each variable can be in only one ring.
-/
vars : PArray Expr := {}
vars : PArray Expr := {}
/-- Mapping from `Expr` to a variable representing it. -/
varMap : PHashMap ENodeKey Var := {}
varMap : PHashMap ENodeKey Var := {}
/-- Mapping from Lean expressions to their representations as `RingExpr` -/
denote : PHashMap ENodeKey RingExpr := {}
denote : PHashMap ENodeKey RingExpr := {}
/-- Next unique id for `EqCnstr`s. -/
nextId : Nat := 0
nextId : Nat := 0
/-- Number of "steps": simplification and superposition. -/
steps : Nat := 0
steps : Nat := 0
/-- Equations to process. -/
queue : Queue := {}
queue : Queue := {}
/--
Mapping from variables `x` to equations such that the smallest variable
in the leading monomial is `x`.
-/
varToBasis : PArray (List EqCnstr) := {}
varToBasis : PArray (List EqCnstr) := {}
deriving Inhabited
/-- State for all `CommRing` types detected by `grind`. -/

View file

@ -61,6 +61,15 @@ def nonzeroCharInst? : RingM (Option (Expr × Nat)) := do
return some (inst, c)
return none
/--
Returns `true` if the current ring satifies the property
```
∀ (k : Nat) (a : α), k ≠ 0 → OfNat.ofNat (α := α) k * a = 0 → a = 0
```
-/
def noZeroDivisors : RingM Bool := do
return (← getRing).noZeroDivInst?.isSome
/-- Returns `true` if the current ring has a `IsCharP` instance. -/
def hasChar : RingM Bool := do
return (← getRing).charInst?.isSome

View file

@ -13,12 +13,26 @@ example (x : Int) : (x + 1)*(x - 1) = x^2 - 1 := by
example (x : UInt8) : (x + 16)*(x - 16) = x^2 := by
grind +ring
/--
info: [grind.ring] new ring: Int
[grind.ring] characteristic: 0
[grind.ring] NoZeroNatDivisors available: true
-/
#guard_msgs (info) in
set_option trace.grind.ring true in
example (x : Int) : (x + 1)^2 - 1 = x^2 + 2*x := by
grind +ring
example (x : BitVec 8) : (x + 16)*(x - 16) = x^2 := by
grind +ring
/--
info: [grind.ring] new ring: BitVec 8
[grind.ring] characteristic: 256
[grind.ring] NoZeroNatDivisors available: false
-/
#guard_msgs (info) in
set_option trace.grind.ring true in
example (x : BitVec 8) : (x + 1)^2 - 1 = x^2 + 2*x := by
grind +ring

View file

@ -49,16 +49,16 @@ def simp? (p₁ p₂ : Poly) : Option Poly :=
partial def simp' (p₁ p₂ : Poly) : Poly :=
if let some r := p₁.simp? p₂ then
assert! r.p == (p₁.mulMon r.k₁ r.m).combine (p₂.mulConst r.k₂)
simp' p₁ r.p
assert! r.p == (p₂.mulMon r.k₂ r.m₂).combine (p₁.mulConst r.k₁)
simp' r.p p₂
else
p
p
def check_simp' (e₁ e₂ r : Expr) : Bool :=
r.toPoly == simp' e₁.toPoly e₂.toPoly
example : check_simp' (x*y - y) (x^2*y - 1) (y - 1) := by native_decide
example : check_simp' (2*x + 1) (x^2 + x + 1) 3 := by native_decide
example : check_simp' (2*x + 1) (3*x^2 + x + y + 1) (4*y + 5) := by native_decide
example : check_simp' (2*x + y) (3*x^2 + x + y + 1) (3*y^2 + 2*y + 4) := by native_decide
example : check_simp' (2*x + 1) (z^4 + w^3 + x^2 + x + 1) (4*z^4 + 4*w^3 + 3) := by native_decide
example : check_simp' (x^2*y - 1) (x*y - y) (y - 1) := by native_decide
example : check_simp' (x^2 + x + 1) (2*x + 1) 3 := by native_decide
example : check_simp' (3*x^2 + x + y + 1) (2*x + 1) (4*y + 5) := by native_decide
example : check_simp' (3*x^2 + x + y + 1) (2*x + y) (3*y^2 + 2*y + 4) := by native_decide
example : check_simp' (z^4 + w^3 + x^2 + x + 1) (2*x + 1) (4*z^4 + 4*w^3 + 3) := by native_decide