fix: ring module in grind (#8713)
This PR fixes a bug in the commutative ring module used in `grind`. It was missing simplification opportunities.
This commit is contained in:
parent
0b2884bfa3
commit
2a63b392dd
9 changed files with 64 additions and 57 deletions
|
|
@ -38,5 +38,6 @@ builtin_initialize registerTraceClass `grind.debug.ring.proof
|
|||
builtin_initialize registerTraceClass `grind.debug.ring.check
|
||||
builtin_initialize registerTraceClass `grind.debug.ring.impEq
|
||||
builtin_initialize registerTraceClass `grind.debug.ring.simpBasis
|
||||
builtin_initialize registerTraceClass `grind.debug.ring.basis
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -48,15 +48,11 @@ then the leading coefficient of the equation must also divide `k`
|
|||
def _root_.Lean.Grind.CommRing.Mon.findSimp? (k : Int) (m : Mon) : RingM (Option EqCnstr) := do
|
||||
let checkCoeff ← checkCoeffDvd
|
||||
let noZeroDiv ← noZeroDivisors
|
||||
let rec go : Mon → RingM (Option EqCnstr)
|
||||
| .unit => return none
|
||||
| .mult pw m' => do
|
||||
for c in (← getRing).varToBasis[pw.x]! do
|
||||
if !checkCoeff || noZeroDiv || (c.p.lc ∣ k) then
|
||||
if c.p.divides m then
|
||||
return some c
|
||||
go m'
|
||||
go m
|
||||
for c in (← getRing).basis do
|
||||
if !checkCoeff || noZeroDiv || (c.p.lc ∣ k) then
|
||||
if c.p.divides m then
|
||||
return some c
|
||||
return none
|
||||
|
||||
/--
|
||||
Returns `some c`, where `c` is an equation from the basis whose leading monomial divides some
|
||||
|
|
@ -129,6 +125,7 @@ def EqCnstr.checkConstant (c : EqCnstr) : RingM Bool := do
|
|||
c.setUnsat
|
||||
else
|
||||
-- Remark: we currently don't do anything if the characteristic is not known.
|
||||
-- TODO: if `k.natAbs` is `1`, we could set all terms of this ring `0`.
|
||||
trace_goal[grind.ring.assert.discard] "{← c.denoteExpr}"
|
||||
return true
|
||||
|
||||
|
|
@ -153,10 +150,9 @@ private def addSorted (c : EqCnstr) : List EqCnstr → List EqCnstr
|
|||
c' :: addSorted c cs
|
||||
|
||||
def addToBasisCore (c : EqCnstr) : RingM Unit := do
|
||||
let .add _ m _ := c.p | return ()
|
||||
let .mult pw _ := m | return ()
|
||||
trace[grind.debug.ring.basis] "{← c.denoteExpr}"
|
||||
modifyRing fun s => { s with
|
||||
varToBasis := s.varToBasis.modify pw.x (addSorted c)
|
||||
basis := addSorted c s.basis
|
||||
recheck := true
|
||||
}
|
||||
|
||||
|
|
@ -168,18 +164,12 @@ def EqCnstr.addToQueue (c : EqCnstr) : RingM Unit := do
|
|||
def EqCnstr.superposeWith (c : EqCnstr) : RingM Unit := do
|
||||
if (← checkMaxSteps) then return ()
|
||||
let .add _ m _ := c.p | return ()
|
||||
go m
|
||||
where
|
||||
go : Mon → RingM Unit
|
||||
| .unit => return ()
|
||||
| .mult pw m => do
|
||||
let x := pw.x
|
||||
let cs := (← getRing).varToBasis[x]!
|
||||
for c' in cs do
|
||||
let r ← c.p.spolM c'.p
|
||||
trace_goal[grind.ring.superpose] "{← c.denoteExpr}\nwith: {← c'.denoteExpr}\nresult: {← r.spol.denoteExpr} = 0"
|
||||
addToQueue (← mkEqCnstr r.spol <| .superpose r.k₁ r.m₁ c r.k₂ r.m₂ c')
|
||||
go m
|
||||
for c' in (← getRing).basis do
|
||||
let .add _ m' _ := c'.p | pure ()
|
||||
if m.sharesVar m' then
|
||||
let r ← c.p.spolM c'.p
|
||||
trace_goal[grind.ring.superpose] "{← c.denoteExpr}\nwith: {← c'.denoteExpr}\nresult: {← r.spol.denoteExpr} = 0"
|
||||
addToQueue (← mkEqCnstr r.spol <| .superpose r.k₁ r.m₁ c r.k₂ r.m₂ c')
|
||||
|
||||
/--
|
||||
Tries to convert the leading monomial into a monic one.
|
||||
|
|
@ -215,25 +205,23 @@ def EqCnstr.toMonic (c : EqCnstr) : RingM EqCnstr := do
|
|||
def EqCnstr.simplifyBasis (c : EqCnstr) : RingM Unit := do
|
||||
trace[grind.debug.ring.simpBasis] "using: {← c.denoteExpr}"
|
||||
let .add _ m _ := c.p | return ()
|
||||
let rec go (m' : Mon) : RingM Unit := do
|
||||
match m' with
|
||||
| .unit => return ()
|
||||
| .mult pw m' => goVar m pw.x; go m'
|
||||
go m
|
||||
where
|
||||
goVar (m : Mon) (x : Var) : RingM Unit := do
|
||||
let cs := (← getRing).varToBasis[x]!
|
||||
if cs.isEmpty then return ()
|
||||
modifyRing fun s => { s with varToBasis := s.varToBasis.set x {} }
|
||||
for c' in cs do
|
||||
trace[grind.debug.ring.simpBasis] "target: {← c'.denoteExpr}"
|
||||
let .add _ m' _ := c'.p | pure ()
|
||||
if m.divides m' then
|
||||
let c'' ← c'.simplifyWithExhaustively c
|
||||
trace[grind.debug.ring.simpBasis] "simplified: {← c''.denoteExpr}"
|
||||
addToQueue c''
|
||||
else
|
||||
addToBasisCore c'
|
||||
let rec go (basis : List EqCnstr) (acc : List EqCnstr) : RingM (List EqCnstr) := do
|
||||
match basis with
|
||||
| [] => return acc.reverse
|
||||
| c' :: basis =>
|
||||
match c'.p with
|
||||
| .add _ m' _ =>
|
||||
if m.divides m' then
|
||||
let c'' ← c'.simplifyWithExhaustively c
|
||||
trace[grind.debug.ring.simpBasis] "simplified: {← c''.denoteExpr}"
|
||||
unless (← checkConstant c'') do
|
||||
addToQueue c''
|
||||
go basis acc
|
||||
else
|
||||
go basis (c' :: acc)
|
||||
| _ => go basis (c' :: acc)
|
||||
let basis ← go (← getRing).basis []
|
||||
modifyRing fun s => { s with basis }
|
||||
|
||||
def EqCnstr.addToBasisAfterSimp (c : EqCnstr) : RingM Unit := do
|
||||
let c ← c.toMonic
|
||||
|
|
|
|||
|
|
@ -29,12 +29,8 @@ private def checkPoly (p : Poly) : RingM Unit := do
|
|||
|
||||
private def checkBasis : RingM Unit := do
|
||||
let mut x := 0
|
||||
for cs in (← getRing).varToBasis do
|
||||
for c in cs do
|
||||
checkPoly c.p
|
||||
let .add _ m _ := c.p | unreachable!
|
||||
let .mult pw _ := m | unreachable!
|
||||
assert! pw.x == x
|
||||
for c in (← getRing).basis do
|
||||
checkPoly c.p
|
||||
x := x + 1
|
||||
|
||||
private def checkQueue : RingM Unit := do
|
||||
|
|
|
|||
|
|
@ -24,9 +24,8 @@ private def push (msgs : Array MessageData) (msg? : Option MessageData) : Array
|
|||
|
||||
def ppBasis? : ReaderT Ring MetaM (Option MessageData) := do
|
||||
let mut basis := #[]
|
||||
for cs in (← getRing).varToBasis do
|
||||
for c in cs do
|
||||
basis := basis.push (toTraceElem (← c.denoteExpr))
|
||||
for c in (← getRing).basis do
|
||||
basis := basis.push (toTraceElem (← c.denoteExpr))
|
||||
return toOption `basis "Basis" basis
|
||||
|
||||
def ppDiseqs? : ReaderT Ring MetaM (Option MessageData) := do
|
||||
|
|
|
|||
|
|
@ -7,6 +7,16 @@ prelude
|
|||
import Init.Grind.CommRing.Poly
|
||||
namespace Lean.Grind.CommRing
|
||||
|
||||
/-- `sharesVar m₁ m₂` returns `true` if `m₁` and `m₂` shares at least one variable. -/
|
||||
def Mon.sharesVar : Mon → Mon → Bool
|
||||
| .unit, _ => false
|
||||
| _, .unit => false
|
||||
| .mult pw₁ m₁, .mult pw₂ m₂ =>
|
||||
match compare pw₁.x pw₂.x with
|
||||
| .eq => true
|
||||
| .lt => sharesVar m₁ (.mult pw₂ m₂)
|
||||
| .gt => sharesVar (.mult pw₁ m₁) m₂
|
||||
|
||||
/-- `lcm m₁ m₂` returns the least common multiple of the given monomials. -/
|
||||
def Mon.lcm : Mon → Mon → Mon
|
||||
| .unit, m₂ => m₂
|
||||
|
|
|
|||
|
|
@ -165,10 +165,11 @@ structure Ring where
|
|||
/-- Equations to process. -/
|
||||
queue : Queue := {}
|
||||
/--
|
||||
Mapping from variables `x` to equations such that the smallest variable
|
||||
in the leading monomial is `x`.
|
||||
The basis is currently just a list. If this is a performance bottleneck, we should use
|
||||
a better data-structure. For examples, we could use a simple indexing for the linear case
|
||||
where we map variable in the leading monomial to `EqCnstr`.
|
||||
-/
|
||||
varToBasis : PArray (List EqCnstr) := {}
|
||||
basis : List EqCnstr := {}
|
||||
/-- Disequalities. -/
|
||||
-- TODO: add indexing
|
||||
diseqs : PArray DiseqCnstr := {}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ def mkVar (e : Expr) : RingM Var := do
|
|||
modifyRing fun s => { s with
|
||||
vars := s.vars.push e
|
||||
varMap := s.varMap.insert { expr := e } var
|
||||
varToBasis := s.varToBasis.push []
|
||||
}
|
||||
setTermRingId e
|
||||
markAsCommRingTerm e
|
||||
|
|
|
|||
|
|
@ -66,3 +66,16 @@ set_option trace.grind.ring.assert.queue true in
|
|||
example (x y : Int) : x + 16*y^2 - 7*x^2 = 0 → False := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
|
||||
/--
|
||||
trace: [grind.debug.ring.basis] a ^ 2 * b + -1 = 0
|
||||
[grind.debug.ring.basis] a * b ^ 2 + -1 * b = 0
|
||||
[grind.debug.ring.basis] a * b + -1 * b = 0
|
||||
[grind.debug.ring.basis] b + -1 = 0
|
||||
[grind.debug.ring.basis] a + -1 = 0
|
||||
-/
|
||||
#guard_msgs (drop error, trace) in
|
||||
set_option trace.grind.debug.ring.basis true in
|
||||
example [CommRing α] (a b c : α)
|
||||
: a^2*b = 1 → a*b^2 = b → False := by
|
||||
grind
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ example (a b c : Int) (f : Int → Nat)
|
|||
f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by
|
||||
grind
|
||||
|
||||
example [CommRing α] (a b c : α) (f : α → Nat)
|
||||
example [CommRing α] [NoNatZeroDivisors α] (a b c : α) (f : α → Nat)
|
||||
: a + b + c = 3 →
|
||||
a^2 + b^2 + c^2 = 5 →
|
||||
a^3 + b^3 + c^3 = 7 →
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue