From 2a63b392ddad0bef1d42eb973fe50964644de236 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 10 Jun 2025 21:20:50 -0400 Subject: [PATCH] fix: ring module in `grind` (#8713) This PR fixes a bug in the commutative ring module used in `grind`. It was missing simplification opportunities. --- .../Meta/Tactic/Grind/Arith/CommRing.lean | 1 + .../Tactic/Grind/Arith/CommRing/EqCnstr.lean | 74 ++++++++----------- .../Meta/Tactic/Grind/Arith/CommRing/Inv.lean | 8 +- .../Meta/Tactic/Grind/Arith/CommRing/PP.lean | 5 +- .../Tactic/Grind/Arith/CommRing/Poly.lean | 10 +++ .../Tactic/Grind/Arith/CommRing/Types.lean | 7 +- .../Meta/Tactic/Grind/Arith/CommRing/Var.lean | 1 - tests/lean/run/grind_ring_1.lean | 13 ++++ tests/lean/run/grind_ring_2.lean | 2 +- 9 files changed, 64 insertions(+), 57 deletions(-) diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean index 886b1437b1..fd321ab099 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean @@ -38,5 +38,6 @@ builtin_initialize registerTraceClass `grind.debug.ring.proof builtin_initialize registerTraceClass `grind.debug.ring.check builtin_initialize registerTraceClass `grind.debug.ring.impEq builtin_initialize registerTraceClass `grind.debug.ring.simpBasis +builtin_initialize registerTraceClass `grind.debug.ring.basis end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean index 1aaac003ae..5a86a761f1 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean @@ -48,15 +48,11 @@ then the leading coefficient of the equation must also divide `k` def _root_.Lean.Grind.CommRing.Mon.findSimp? (k : Int) (m : Mon) : RingM (Option EqCnstr) := do let checkCoeff ← checkCoeffDvd let noZeroDiv ← noZeroDivisors - let rec go : Mon → RingM (Option EqCnstr) - | .unit => return none - | .mult pw m' => do - for c in (← getRing).varToBasis[pw.x]! do - if !checkCoeff || noZeroDiv || (c.p.lc ∣ k) then - if c.p.divides m then - return some c - go m' - go m + for c in (← getRing).basis do + if !checkCoeff || noZeroDiv || (c.p.lc ∣ k) then + if c.p.divides m then + return some c + return none /-- Returns `some c`, where `c` is an equation from the basis whose leading monomial divides some @@ -129,6 +125,7 @@ def EqCnstr.checkConstant (c : EqCnstr) : RingM Bool := do c.setUnsat else -- Remark: we currently don't do anything if the characteristic is not known. + -- TODO: if `k.natAbs` is `1`, we could set all terms of this ring `0`. trace_goal[grind.ring.assert.discard] "{← c.denoteExpr}" return true @@ -153,10 +150,9 @@ private def addSorted (c : EqCnstr) : List EqCnstr → List EqCnstr c' :: addSorted c cs def addToBasisCore (c : EqCnstr) : RingM Unit := do - let .add _ m _ := c.p | return () - let .mult pw _ := m | return () + trace[grind.debug.ring.basis] "{← c.denoteExpr}" modifyRing fun s => { s with - varToBasis := s.varToBasis.modify pw.x (addSorted c) + basis := addSorted c s.basis recheck := true } @@ -168,18 +164,12 @@ def EqCnstr.addToQueue (c : EqCnstr) : RingM Unit := do def EqCnstr.superposeWith (c : EqCnstr) : RingM Unit := do if (← checkMaxSteps) then return () let .add _ m _ := c.p | return () - go m -where - go : Mon → RingM Unit - | .unit => return () - | .mult pw m => do - let x := pw.x - let cs := (← getRing).varToBasis[x]! - for c' in cs do - let r ← c.p.spolM c'.p - trace_goal[grind.ring.superpose] "{← c.denoteExpr}\nwith: {← c'.denoteExpr}\nresult: {← r.spol.denoteExpr} = 0" - addToQueue (← mkEqCnstr r.spol <| .superpose r.k₁ r.m₁ c r.k₂ r.m₂ c') - go m + for c' in (← getRing).basis do + let .add _ m' _ := c'.p | pure () + if m.sharesVar m' then + let r ← c.p.spolM c'.p + trace_goal[grind.ring.superpose] "{← c.denoteExpr}\nwith: {← c'.denoteExpr}\nresult: {← r.spol.denoteExpr} = 0" + addToQueue (← mkEqCnstr r.spol <| .superpose r.k₁ r.m₁ c r.k₂ r.m₂ c') /-- Tries to convert the leading monomial into a monic one. @@ -215,25 +205,23 @@ def EqCnstr.toMonic (c : EqCnstr) : RingM EqCnstr := do def EqCnstr.simplifyBasis (c : EqCnstr) : RingM Unit := do trace[grind.debug.ring.simpBasis] "using: {← c.denoteExpr}" let .add _ m _ := c.p | return () - let rec go (m' : Mon) : RingM Unit := do - match m' with - | .unit => return () - | .mult pw m' => goVar m pw.x; go m' - go m -where - goVar (m : Mon) (x : Var) : RingM Unit := do - let cs := (← getRing).varToBasis[x]! - if cs.isEmpty then return () - modifyRing fun s => { s with varToBasis := s.varToBasis.set x {} } - for c' in cs do - trace[grind.debug.ring.simpBasis] "target: {← c'.denoteExpr}" - let .add _ m' _ := c'.p | pure () - if m.divides m' then - let c'' ← c'.simplifyWithExhaustively c - trace[grind.debug.ring.simpBasis] "simplified: {← c''.denoteExpr}" - addToQueue c'' - else - addToBasisCore c' + let rec go (basis : List EqCnstr) (acc : List EqCnstr) : RingM (List EqCnstr) := do + match basis with + | [] => return acc.reverse + | c' :: basis => + match c'.p with + | .add _ m' _ => + if m.divides m' then + let c'' ← c'.simplifyWithExhaustively c + trace[grind.debug.ring.simpBasis] "simplified: {← c''.denoteExpr}" + unless (← checkConstant c'') do + addToQueue c'' + go basis acc + else + go basis (c' :: acc) + | _ => go basis (c' :: acc) + let basis ← go (← getRing).basis [] + modifyRing fun s => { s with basis } def EqCnstr.addToBasisAfterSimp (c : EqCnstr) : RingM Unit := do let c ← c.toMonic diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Inv.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Inv.lean index b69b87274d..209d366ab2 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Inv.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Inv.lean @@ -29,12 +29,8 @@ private def checkPoly (p : Poly) : RingM Unit := do private def checkBasis : RingM Unit := do let mut x := 0 - for cs in (← getRing).varToBasis do - for c in cs do - checkPoly c.p - let .add _ m _ := c.p | unreachable! - let .mult pw _ := m | unreachable! - assert! pw.x == x + for c in (← getRing).basis do + checkPoly c.p x := x + 1 private def checkQueue : RingM Unit := do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/PP.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/PP.lean index 4a3e7b50db..5f786c0189 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/PP.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/PP.lean @@ -24,9 +24,8 @@ private def push (msgs : Array MessageData) (msg? : Option MessageData) : Array def ppBasis? : ReaderT Ring MetaM (Option MessageData) := do let mut basis := #[] - for cs in (← getRing).varToBasis do - for c in cs do - basis := basis.push (toTraceElem (← c.denoteExpr)) + for c in (← getRing).basis do + basis := basis.push (toTraceElem (← c.denoteExpr)) return toOption `basis "Basis" basis def ppDiseqs? : ReaderT Ring MetaM (Option MessageData) := do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean index 0d3d8f1eec..1b97f815d9 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean @@ -7,6 +7,16 @@ prelude import Init.Grind.CommRing.Poly namespace Lean.Grind.CommRing +/-- `sharesVar m₁ m₂` returns `true` if `m₁` and `m₂` shares at least one variable. -/ +def Mon.sharesVar : Mon → Mon → Bool + | .unit, _ => false + | _, .unit => false + | .mult pw₁ m₁, .mult pw₂ m₂ => + match compare pw₁.x pw₂.x with + | .eq => true + | .lt => sharesVar m₁ (.mult pw₂ m₂) + | .gt => sharesVar (.mult pw₁ m₁) m₂ + /-- `lcm m₁ m₂` returns the least common multiple of the given monomials. -/ def Mon.lcm : Mon → Mon → Mon | .unit, m₂ => m₂ diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean index e267ec970b..e63b78517f 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean @@ -165,10 +165,11 @@ structure Ring where /-- Equations to process. -/ queue : Queue := {} /-- - Mapping from variables `x` to equations such that the smallest variable - in the leading monomial is `x`. + The basis is currently just a list. If this is a performance bottleneck, we should use + a better data-structure. For examples, we could use a simple indexing for the linear case + where we map variable in the leading monomial to `EqCnstr`. -/ - varToBasis : PArray (List EqCnstr) := {} + basis : List EqCnstr := {} /-- Disequalities. -/ -- TODO: add indexing diseqs : PArray DiseqCnstr := {} diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean index 46fe59d8ac..17b1946ea0 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean @@ -16,7 +16,6 @@ def mkVar (e : Expr) : RingM Var := do modifyRing fun s => { s with vars := s.vars.push e varMap := s.varMap.insert { expr := e } var - varToBasis := s.varToBasis.push [] } setTermRingId e markAsCommRingTerm e diff --git a/tests/lean/run/grind_ring_1.lean b/tests/lean/run/grind_ring_1.lean index fef54b1d1c..bef6809046 100644 --- a/tests/lean/run/grind_ring_1.lean +++ b/tests/lean/run/grind_ring_1.lean @@ -66,3 +66,16 @@ set_option trace.grind.ring.assert.queue true in example (x y : Int) : x + 16*y^2 - 7*x^2 = 0 → False := by fail_if_success grind sorry + +/-- +trace: [grind.debug.ring.basis] a ^ 2 * b + -1 = 0 +[grind.debug.ring.basis] a * b ^ 2 + -1 * b = 0 +[grind.debug.ring.basis] a * b + -1 * b = 0 +[grind.debug.ring.basis] b + -1 = 0 +[grind.debug.ring.basis] a + -1 = 0 +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.debug.ring.basis true in +example [CommRing α] (a b c : α) + : a^2*b = 1 → a*b^2 = b → False := by + grind diff --git a/tests/lean/run/grind_ring_2.lean b/tests/lean/run/grind_ring_2.lean index beb6caf782..7850fba020 100644 --- a/tests/lean/run/grind_ring_2.lean +++ b/tests/lean/run/grind_ring_2.lean @@ -126,7 +126,7 @@ example (a b c : Int) (f : Int → Nat) f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by grind -example [CommRing α] (a b c : α) (f : α → Nat) +example [CommRing α] [NoNatZeroDivisors α] (a b c : α) (f : α → Nat) : a + b + c = 3 → a^2 + b^2 + c^2 = 5 → a^3 + b^3 + c^3 = 7 →