feat: add grind (ringSteps := <num>) (#8131)

This PR adds a configuration option that controls the maximum number of
steps the commutative-ring procedure in `grind` performs.
This commit is contained in:
Leonardo de Moura 2025-04-27 10:46:02 -07:00 committed by GitHub
parent 36ed58351d
commit d73557321b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 23 additions and 3 deletions

View file

@ -118,6 +118,7 @@ structure Config where
When `true` (default: `false`), uses procedure for handling equalities over commutative rings.
-/
ring := false
ringSteps := 10000
deriving Inhabited, BEq
end Lean.Grind

View file

@ -71,6 +71,7 @@ def _root_.Lean.Grind.CommRing.Poly.findSimp? (p : Poly) : RingM (Option EqCnstr
/-- Simplifies `d.p` using `c`, and returns an extended polynomial derivation. -/
def PolyDerivation.simplify1 (d : PolyDerivation) (c : EqCnstr) : RingM (Option PolyDerivation) := do
let some r := d.p.simp? c.p (← nonzeroChar?) | return none
incSteps
trace_goal[grind.ring.simp] "{← r.p.denoteExpr}"
return some <| .step r.p r.k₁ d r.k₂ r.m₂ c
@ -78,7 +79,7 @@ def PolyDerivation.simplify1 (d : PolyDerivation) (c : EqCnstr) : RingM (Option
def PolyDerivation.simplifyWith (d : PolyDerivation) (c : EqCnstr) : RingM PolyDerivation := do
let mut d := d
repeat
checkSystem "ring"
if (← checkMaxSteps) then return d
let some r ← d.simplify1 c | return d
trace_goal[grind.debug.ring.simp] "simplifying{indentD (← d.denoteExpr)}\nwith{indentD (← c.denoteExpr)}"
d := r
@ -88,6 +89,7 @@ def PolyDerivation.simplifyWith (d : PolyDerivation) (c : EqCnstr) : RingM PolyD
def PolyDerivation.simplify (d : PolyDerivation) : RingM PolyDerivation := do
let mut d := d
repeat
if (← checkMaxSteps) then return d
let some c ← d.p.findSimp? |
trace_goal[grind.debug.ring.simp] "simplified{indentD (← d.denoteExpr)}"
return d
@ -101,6 +103,7 @@ def EqCnstr.simplify1 (c₁ c₂ : EqCnstr) : RingM (Option EqCnstr) := do
p := r.p
h := .simp r.k₁ c₁ r.k₂ r.m₂ c₂
}
incSteps
trace_goal[grind.ring.simp] "{← c.p.denoteExpr}"
return some c
@ -108,7 +111,7 @@ def EqCnstr.simplify1 (c₁ c₂ : EqCnstr) : RingM (Option EqCnstr) := do
def EqCnstr.simplifyWith (c c' : EqCnstr) : RingM EqCnstr := do
let mut c := c
repeat
checkSystem "ring"
if (← checkMaxSteps) then return c
let some r ← c.simplify1 c' | return c
trace_goal[grind.debug.ring.simp] "simplifying{indentD (← c.denoteExpr)}\nwith{indentD (← c'.denoteExpr)}"
c := r
@ -118,6 +121,7 @@ def EqCnstr.simplifyWith (c c' : EqCnstr) : RingM EqCnstr := do
def EqCnstr.simplify (c : EqCnstr) : RingM EqCnstr := do
let mut c := c
repeat
if (← checkMaxSteps) then return c
let some c' ← c.p.findSimp? |
trace_goal[grind.debug.ring.simp] "simplified{indentD (← c.denoteExpr)}"
return c
@ -174,10 +178,12 @@ def EqCnstr.simplifyBasis (c : EqCnstr) : RingM Unit := do
addToBasisCore c'
def EqCnstr.addToQueue (c : EqCnstr) : RingM Unit := do
if (← checkMaxSteps) then return ()
trace_goal[grind.ring.assert.queue] "{← c.denoteExpr}"
modifyRing fun s => { s with queue := s.queue.insert c }
def EqCnstr.superposeWith (c : EqCnstr) : RingM Unit := do
if (← checkMaxSteps) then return ()
let .add _ m _ := c.p | return ()
go m
where
@ -310,6 +316,7 @@ private def propagateEqs : RingM Unit := do
-/
let mut map : PropagateEqMap := {}
for (a, ra) in (← getRing).denote do
if (← checkMaxSteps) then return ()
let a := a.expr
let d : PolyDerivation := .input (← ra.toPolyM)
let d ← d.simplify
@ -334,12 +341,14 @@ def checkRing : RingM Bool := do
trace_goal[grind.debug.ring.check] "{← c.denoteExpr}"
c.addToBasis
if (← isInconsistent) then return true
if (← checkMaxSteps) then return true
checkDiseqs
propagateEqs
modifyRing fun s => { s with recheck := false }
return true
def check : GoalM Bool := do
if (← checkMaxSteps) then return false
let mut progress := false
for ringId in [:(← get').rings.size] do
let r ← RingM.run ringId checkRing

View file

@ -188,6 +188,7 @@ structure State where
typeIdOf : PHashMap ENodeKey (Option Nat) := {}
/- Mapping from expressions/terms to their ring ids. -/
exprToRingId : PHashMap ENodeKey Nat := {}
steps := 0
deriving Inhabited
end Lean.Meta.Grind.Arith.CommRing

View file

@ -15,6 +15,12 @@ def get' : GoalM State := do
@[inline] def modify' (f : State → State) : GoalM Unit := do
modify fun s => { s with arith.ring := f s.arith.ring }
def checkMaxSteps : GoalM Bool := do
return (← get').steps >= (← getConfig).ringSteps
def incSteps : GoalM Unit := do
modify' fun s => { s with steps := s.steps + 1 }
/-- We don't want to keep carrying the `RingId` around. -/
abbrev RingM := ReaderT Nat GoalM
@ -113,6 +119,7 @@ def isQueueEmpty : RingM Bool :=
def getNext? : RingM (Option EqCnstr) := do
let some c := (← getRing).queue.min | return none
modifyRing fun s => { s with queue := s.queue.erase c }
incSteps
return some c
end Lean.Meta.Grind.Arith.CommRing

View file

@ -79,7 +79,7 @@ def mbtc (ctx : MBTC.Context) : GoalM Bool := do
if candidates.isEmpty then
return false
if (← get).split.num > (← getConfig).splits then
reportIssue "skipping `mbtc`, maximum number of splits has been reached `(splits := {(← getConfig).splits})`"
reportIssue! "skipping `mbtc`, maximum number of splits has been reached `(splits := {(← getConfig).splits})`"
return false
let result := candidates.toArray.qsort fun c₁ c₂ => c₁.lt c₂
let result ← result.filterMapM fun info => do

View file

@ -153,6 +153,8 @@ private def ppThresholds (c : Grind.Config) : M Unit := do
msgs := msgs.push <| .trace { cls := `limit } m!"maximum number of case-splits has been reached, threshold: `(splits := {c.splits})`" #[]
if maxGen ≥ c.gen then
msgs := msgs.push <| .trace { cls := `limit } m!"maximum term generation has been reached, threshold: `(gen := {c.gen})`" #[]
if goal.arith.ring.steps ≥ c.ringSteps then
msgs := msgs.push <| .trace { cls := `limit } m!"maximum number of ring steps has been reached, threshold: `(ringSteps := {c.ringSteps})`" #[]
unless msgs.isEmpty do
pushMsg <| .trace { cls := `limits } "Thresholds reached" msgs