fix: limit ring solver polynomial degree in grind (#13585)
This PR adds a `ringMaxDegree` configuration option (default `1024`) that bounds the maximum degree of polynomials processed by the `grind` ring solver. Equality constraints whose polynomial exceeds this threshold are discarded (with an issue reported once per goal), preventing pathological degree explosion on inputs such as `r ^ (2 ^ 250 - 1)`. This PR also introduces `Poly.simpM?`, a monadic version of `Poly.simp?` built on the existing safe arithmetic primitives (`mulMonM`, `combineM`, `mulConstM`) in `Grind.Arith.CommRing.SafePoly`. The previous reflection-oriented `Poly.simp?` in `Sym.Arith.Poly` lacked the abort mechanisms needed during proof search, so the simplification path used by `EqCnstr` now goes through the safe variant. A regression test `tests/elab/grind_ring_degree_explosion.lean` ensures `grind` fails quickly on high-degree problems.
This commit is contained in:
parent
19baa470e5
commit
427e3bcdbc
8 changed files with 79 additions and 73 deletions
|
|
@ -116,6 +116,10 @@ structure Config where
|
|||
-/
|
||||
ringSteps := 100000
|
||||
/--
|
||||
Maximum degree of polynomials processed by the `ring` solver.
|
||||
-/
|
||||
ringMaxDegree := 1024
|
||||
/--
|
||||
When `true` (default: `true`), uses procedure for handling linear arithmetic for `IntModule`, and
|
||||
`CommRing`.
|
||||
-/
|
||||
|
|
|
|||
|
|
@ -128,58 +128,6 @@ def Poly.spol (p₁ p₂ : Poly) (char? : Option Nat := none) : SPolResult :=
|
|||
{ spol, m₁, m₂, k₁ := c₁, k₂ := c₂ }
|
||||
| _, _ => {}
|
||||
|
||||
/--
|
||||
Result of simplifying a polynomial `p₁` using a polynomial `p₂`.
|
||||
|
||||
The simplification rewrites the first monomial of `p₁` that can be divided
|
||||
by the leading monomial of `p₂`.
|
||||
-/
|
||||
structure SimpResult where
|
||||
/-- The resulting simplified polynomial after rewriting. -/
|
||||
p : Poly := .num 0
|
||||
/-- The integer coefficient multiplied with polynomial `p₁` in the rewriting step. -/
|
||||
k₁ : Int := 0
|
||||
/-- The integer coefficient multiplied with polynomial `p₂` during rewriting. -/
|
||||
k₂ : Int := 0
|
||||
/-- The monomial factor applied to polynomial `p₂`. -/
|
||||
m₂ : Mon := .unit
|
||||
|
||||
/--
|
||||
Simplifies polynomial `p₁` using polynomial `p₂` by rewriting.
|
||||
|
||||
This function attempts to rewrite `p₁` by eliminating the first occurrence of
|
||||
the leading monomial of `p₂`.
|
||||
|
||||
Remark: if `char? = some c`, then `c` is the characteristic of the ring.
|
||||
-/
|
||||
def Poly.simp? (p₁ p₂ : Poly) (char? : Option Nat := none) : Option SimpResult :=
|
||||
match p₂ with
|
||||
| .add k₂' m₂ p₂ =>
|
||||
let rec go? (p₁ : Poly) : Option SimpResult :=
|
||||
match p₁ with
|
||||
| .add k₁' m₁ p₁ =>
|
||||
if m₂.divides m₁ then
|
||||
let m₂ := m₁.div m₂
|
||||
let g := Nat.gcd k₁'.natAbs k₂'.natAbs
|
||||
let k₁ := k₂'/g
|
||||
let k₂ := -k₁'/g
|
||||
let p := (p₂.mulMon' k₂ m₂ char?).combine' (p₁.mulConst' k₁ char?) char?
|
||||
some { p, k₁, k₂, m₂ }
|
||||
else if let some r := go? p₁ then
|
||||
if let some char := char? then
|
||||
let k := (k₁'*r.k₁) % char
|
||||
if k == 0 then
|
||||
some r
|
||||
else
|
||||
some { r with p := .add k m₁ r.p }
|
||||
else
|
||||
some { r with p := .add (k₁'*r.k₁) m₁ r.p }
|
||||
else
|
||||
none
|
||||
| .num _ => none
|
||||
go? p₁
|
||||
| _ => none
|
||||
|
||||
def Poly.degree : Poly → Nat
|
||||
| .num _ => 0
|
||||
| .add _ m _ => m.degree
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ def _root_.Lean.Grind.CommRing.Poly.findSimp? (p : Poly) : RingM (Option EqCnstr
|
|||
|
||||
/-- Simplifies `d.p` using `c`, and returns an extended polynomial derivation. -/
|
||||
def PolyDerivation.simplifyWith (d : PolyDerivation) (c : EqCnstr) : RingM PolyDerivation := do
|
||||
let some r := d.p.simp? c.p (← nonzeroChar?) | return d
|
||||
let some r ← d.p.simpM? c.p | return d
|
||||
incSteps r.p.numTerms
|
||||
trace_goal[grind.ring.simp] "{← r.p.denoteExpr}"
|
||||
return .step r.p r.k₁ d r.k₂ r.m₂ c
|
||||
|
|
@ -132,7 +132,7 @@ def PolyDerivation.simplify (d : PolyDerivation) : RingM PolyDerivation := do
|
|||
|
||||
/-- Simplifies `c₁` using `c₂`. -/
|
||||
def EqCnstr.simplifyWithCore (c₁ c₂ : EqCnstr) : RingM (Option EqCnstr) := do
|
||||
let some r := c₁.p.simp? c₂.p (← nonzeroChar?) | return none
|
||||
let some r ← c₁.p.simpM? c₂.p | return none
|
||||
let c := { c₁ with
|
||||
p := r.p
|
||||
h := .simp r.k₁ c₁ r.k₂ r.m₂ c₂
|
||||
|
|
@ -221,6 +221,7 @@ def addToBasisCore (c : EqCnstr) : RingM Unit := do
|
|||
def EqCnstr.addToQueue (c : EqCnstr) : RingM Unit := do
|
||||
if (← checkMaxSteps) then return ()
|
||||
trace_goal[grind.ring.assert.queue] "{← c.denoteExpr}"
|
||||
if (← checkMaxDegree c.p) then return () -- discard
|
||||
modifyCommRing fun s => { s with queue := s.queue.insert c }
|
||||
|
||||
def EqCnstr.superposeWith (c : EqCnstr) : RingM Unit := do
|
||||
|
|
@ -307,6 +308,7 @@ private def checkNumEq0Updated : RingM Unit := do
|
|||
checkNumEq0Updated
|
||||
|
||||
def EqCnstr.addToBasis (c : EqCnstr) : RingM Unit := do
|
||||
if (← checkMaxDegree c.p) then return () -- discard
|
||||
withCheckingNumEq0 do
|
||||
let some c ← c.simplifyAndCheck | return ()
|
||||
c.addToBasisAfterSimp
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ module
|
|||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.SynthInstance
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
|
||||
import Lean.Meta.Sym.Arith.Poly
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
|
|
@ -14,6 +15,15 @@ open Sym.Arith
|
|||
def checkMaxSteps : GoalM Bool := do
|
||||
return (← get').steps >= (← getConfig).ringSteps
|
||||
|
||||
def checkMaxDegree (p : Poly) : GoalM Bool := do
|
||||
if p.degree >= (← getConfig).ringMaxDegree then
|
||||
unless (← get').reportedMaxDegreeIssue do
|
||||
modify' fun s => { s with reportedMaxDegreeIssue := true }
|
||||
reportIssue! "ring polynomial degree {p.degree} exceeds threshold `(ringMaxDegree := {p.degree})`"
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
def incSteps (n : Nat := 1) : GoalM Unit := do
|
||||
modify' fun s => { s with steps := s.steps + n }
|
||||
|
||||
|
|
|
|||
|
|
@ -135,4 +135,54 @@ def _root_.Lean.Grind.CommRing.Poly.findInvNumeralVar? (p : Poly) : RingM (Optio
|
|||
let some r ← m.findInvNumeralVar? | p.findInvNumeralVar?
|
||||
return some r
|
||||
|
||||
/--
|
||||
Result of simplifying a polynomial `p₁` using a polynomial `p₂`.
|
||||
|
||||
The simplification rewrites the first monomial of `p₁` that can be divided
|
||||
by the leading monomial of `p₂`.
|
||||
-/
|
||||
structure SimpResult where
|
||||
/-- The resulting simplified polynomial after rewriting. -/
|
||||
p : Poly := .num 0
|
||||
/-- The integer coefficient multiplied with polynomial `p₁` in the rewriting step. -/
|
||||
k₁ : Int := 0
|
||||
/-- The integer coefficient multiplied with polynomial `p₂` during rewriting. -/
|
||||
k₂ : Int := 0
|
||||
/-- The monomial factor applied to polynomial `p₂`. -/
|
||||
m₂ : Mon := .unit
|
||||
|
||||
/--
|
||||
Simplifies polynomial `p₁` using polynomial `p₂` by rewriting.
|
||||
|
||||
This function attempts to rewrite `p₁` by eliminating the first occurrence of
|
||||
the leading monomial of `p₂`.
|
||||
-/
|
||||
def _root_.Lean.Grind.CommRing.Poly.simpM? (p₁ p₂ : Poly) : RingM (Option SimpResult) := do
|
||||
match p₂ with
|
||||
| .add k₂' m₂ p₂ =>
|
||||
let rec go? (p₁ : Poly) : RingM (Option SimpResult) := do
|
||||
match p₁ with
|
||||
| .add k₁' m₁ p₁ =>
|
||||
if m₂.divides m₁ then
|
||||
let m₂ := m₁.div m₂
|
||||
let g := Nat.gcd k₁'.natAbs k₂'.natAbs
|
||||
let k₁ := k₂'/g
|
||||
let k₂ := -k₁'/g
|
||||
let p ← (← p₂.mulMonM k₂ m₂).combineM (← p₁.mulConstM k₁)
|
||||
return some { p, k₁, k₂, m₂ }
|
||||
else if let some r ← go? p₁ then
|
||||
if let some char ← nonzeroChar? then
|
||||
let k := (k₁'*r.k₁) % char
|
||||
if k == 0 then
|
||||
return some r
|
||||
else
|
||||
return some { r with p := .add k m₁ r.p }
|
||||
else
|
||||
return some { r with p := .add (k₁'*r.k₁) m₁ r.p }
|
||||
else
|
||||
return none
|
||||
| .num _ => return none
|
||||
go? p₁
|
||||
| _ => return none
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -311,6 +311,8 @@ structure State where
|
|||
`ncstypeIdOf[type]` is `some id`, then `id < ncSemirings.size`. -/
|
||||
ncstypeIdOf : PHashMap ExprPtr (Option Nat) := {}
|
||||
steps := 0
|
||||
/-- `true` if solver has already reported max degree issue. -/
|
||||
reportedMaxDegreeIssue : Bool := false
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize ringExt : SolverExtension State ← registerSolverExtension (return {})
|
||||
|
|
|
|||
9
tests/elab/grind_ring_degree_explosion.lean
Normal file
9
tests/elab/grind_ring_degree_explosion.lean
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
set_option warn.sorry false
|
||||
|
||||
/-!
|
||||
`grind` must fail quickly on problems containing high degree polynomials
|
||||
-/
|
||||
|
||||
theorem explosion (r p t3 t19 : Nat) : t19 % p = r ^ (2 ^ 250 - 1) % p ∧ t3 % p = r ^ 11 % p := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
|
|
@ -44,22 +44,3 @@ example : check_spoly (2*x + 3) (3*z + 1) (9*z - 2*x) := by native_decide
|
|||
example : check_spoly (2*y^2 - x + 1) (2*x*y - 1 + y) (-x^2 + y + x - y^2) := by native_decide
|
||||
example : check_spoly (2*y^2 - x + 1) (4*x*y - 1 + y) (-2*x^2 + y + 2*x - y^2) := by native_decide
|
||||
example : check_spoly (6*y^2 - x + 1) (4*x*y - 1 + y) (-2*x^2 + 3*y + 2*x - 3*y^2) := by native_decide
|
||||
|
||||
def simp? (p₁ p₂ : Poly) : Option Poly :=
|
||||
(·.p) <$> p₁.simp? p₂
|
||||
|
||||
partial def simp' (p₁ p₂ : Poly) : Poly :=
|
||||
if let some r := p₁.simp? p₂ then
|
||||
assert! r.p == (p₂.mulMon r.k₂ r.m₂).combine (p₁.mulConst r.k₁)
|
||||
simp' r.p p₂
|
||||
else
|
||||
p₁
|
||||
|
||||
def check_simp' (e₁ e₂ r : Expr) : Bool :=
|
||||
r.toPoly == simp' e₁.toPoly e₂.toPoly
|
||||
|
||||
example : check_simp' (x^2*y - 1) (x*y - y) (y - 1) := by native_decide
|
||||
example : check_simp' (x^2 + x + 1) (2*x + 1) 3 := by native_decide
|
||||
example : check_simp' (3*x^2 + x + y + 1) (2*x + 1) (4*y + 5) := by native_decide
|
||||
example : check_simp' (3*x^2 + x + y + 1) (2*x + y) (3*y^2 + 2*y + 4) := by native_decide
|
||||
example : check_simp' (z^4 + w^3 + x^2 + x + 1) (2*x + 1) (4*z^4 + 4*w^3 + 3) := by native_decide
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue