diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean index e3a4b2605e..689e0d58c8 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean @@ -45,25 +45,30 @@ structure PreNullCert where Thus, we need to track a denominator to justify the proof step `div`. -/ d : Int := 1 + deriving Inhabited def PreNullCert.unit (i : Nat) (n : Nat) : PreNullCert := let qs := Array.replicate n (.num 0) let qs := qs.set! i (.num 1) { qs } -def PreNullCert.mul (c : PreNullCert) (k : Int) (char? : Option Nat) : PreNullCert := - if k == 1 then c +def PreNullCert.div (c : PreNullCert) (k : Int) : RingM PreNullCert := do + return { c with d := c.d * k } + +def PreNullCert.mul (c : PreNullCert) (k : Int) : RingM PreNullCert := do + if k == 1 then + return c else let g := Int.gcd k c.d let k := k / g let d := c.d / g if k == 1 then - { c with d } + return { c with d } else - let qs := c.qs.map fun q => if q.isZero then q else q.mulConst' k char? - { qs, d } + let qs ← c.qs.mapM fun q => q.mulConstM k + return { qs, d } -def PreNullCert.combine (k₁ : Int) (m₁ : Mon) (c₁ : PreNullCert) (k₂ : Int) (m₂ : Mon) (c₂ : PreNullCert) (char? : Option Nat) : PreNullCert := Id.run do +def PreNullCert.combine (k₁ : Int) (m₁ : Mon) (c₁ : PreNullCert) (k₂ : Int) (m₂ : Mon) (c₂ : PreNullCert) : RingM PreNullCert := do let d₁ := c₁.d let d₂ := c₂.d let k₁_d₂ := k₁*d₂ @@ -79,17 +84,17 @@ def PreNullCert.combine (k₁ : Int) (m₁ : Mon) (c₁ : PreNullCert) (k₂ : I let mut qs : Vector Poly n := Vector.replicate n (.num 0) for h : i in [:n] do if h₁ : i < qs₁.size then - let q₁ := qs₁[i].mulMon' k₁ m₁ char? + let q₁ ← qs₁[i].mulMonM k₁ m₁ if h₂ : i < qs₂.size then - let q₂ := qs₂[i].mulMon' k₂ m₂ char? - qs := qs.set i (q₁.combine' q₂ char?) + let q₂ ← qs₂[i].mulMonM k₂ m₂ + qs := qs.set i (← q₁.combineM q₂) else qs := qs.set i q₁ else have : i < n := h.upper have : qs₁.size = n ∨ qs₂.size = n := by simp +zetaDelta [Nat.max_def]; split <;> simp [*] have : i < qs₂.size := by omega - let q₂ := qs₂[i].mulMon' k₂ m₂ char? + let q₂ ← qs₂[i].mulMonM k₂ m₂ qs := qs.set i q₂ return { qs := qs.toArray, d } @@ -101,9 +106,57 @@ structure NullCertHypothesis where structure ProofM.State where /-- Mapping from `EqCnstr` to `PreNullCert` -/ cache : Std.HashMap UInt64 PreNullCert := {} - hypToId : Std.HashMap UInt64 Nat := {} hyps : Array NullCertHypothesis := #[] +abbrev ProofM := StateRefT ProofM.State RingM + +private abbrev caching (c : α) (k : ProofM PreNullCert) : ProofM PreNullCert := do + let addr := unsafe (ptrAddrUnsafe c).toUInt64 >>> 2 + if let some h := (← get).cache[addr]? then + return h + else + let h ← k + modify fun s => { s with cache := s.cache.insert addr h } + return h + +partial def EqCnstr.toPreNullCert (c : EqCnstr) : ProofM PreNullCert := caching c do + match c.h with + | .core a b lhs rhs => + let i := (← get).hyps.size + let h ← mkEqProof a b + modify fun s => { s with hyps := s.hyps.push { h, lhs, rhs } } + return PreNullCert.unit i (i+1) + | .superpose c₁ c₂ k₁ k₂ m₁ m₂ => (← c₁.toPreNullCert).combine k₁ m₁ k₂ m₂ (← c₂.toPreNullCert) + | .simp c₁ c₂ k₁ k₂ m => (← c₁.toPreNullCert).combine k₁ m k₂ .unit (← c₂.toPreNullCert) + | .mul k c => (← c.toPreNullCert).mul k + | .div k c => (← c.toPreNullCert).div k + +structure NullCertExt where + d : Int + qhs : Array (Poly × NullCertHypothesis) + +def EqCnstr.mkNullCertExt (c : EqCnstr) : RingM NullCertExt := do + let (nc, s) ← c.toPreNullCert.run {} + return { d := nc.d, qhs := nc.qs.zip s.hyps } + +def NullCertExt.toPoly (nc : NullCertExt) : RingM Poly := do + let mut p : Poly := .num 0 + for (q, h) in nc.qhs do + p ← p.combineM (← q.mulM (← (h.lhs.sub h.rhs).toPolyM)) + return p + +def NullCertExt.check (c : EqCnstr) (nc : NullCertExt) : RingM Bool := do + let p₁ := c.p.mulConst' nc.d (← nonzeroChar?) + let p₂ ← nc.toPoly + return p₁ == p₂ + +def setInconsistent (c : EqCnstr) : RingM Unit := do + trace_goal[grind.ring.assert.unsat] "{← c.denoteExpr}" + let nc ← c.mkNullCertExt + trace_goal[grind.ring.assert.unsat] "{nc.d}*({← c.p.denoteExpr}), {← (← nc.toPoly).denoteExpr}" + trace_goal[grind.ring.assert.unsat] "{nc.d}*({← c.p.denoteExpr}), {← nc.qhs.mapM fun (p, h) => return (← p.denoteExpr, ← h.lhs.denoteExpr, ← h.rhs.denoteExpr) }" + -- TODO + private def mkLemmaPrefix (declName declNameC : Name) : RingM Expr := do let ring ← getRing let ctx ← toContextExpr @@ -123,8 +176,5 @@ def setEqUnsat (k : Int) (a b : Expr) (ra rb : RingExpr) : RingM Unit := do h := mkApp h charInst closeGoal <| mkApp5 h (toExpr ra) (toExpr rb) (toExpr k) reflBoolTrue (← mkEqProof a b) -def setInconsistent (c : EqCnstr) : RingM Unit := do - trace_goal[grind.ring.assert.unsat] "{← c.denoteExpr}" - -- TODO end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean index 1ec4d0b046..d5bc0afbcc 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean @@ -26,7 +26,7 @@ structure EqCnstr where inductive EqCnstrProof where | core (a b : Expr) (ra rb : RingExpr) - | superpose (c₁ c₂ : EqCnstr) + | superpose (c₁ c₂ : EqCnstr) (k₁ k₂ : Int) (m₁ m₂ : Mon) | simp (c₁ c₂ : EqCnstr) (k₁ k₂ : Int) (m : Mon) | mul (k : Int) (e : EqCnstr) | div (k : Int) (e : EqCnstr) @@ -104,7 +104,7 @@ inductive SimpChain where ``` If we have a commutative ring where ``` - ∀ (k : Int) (a b : α), k ≠ 0 → (intCast k) * a = 0 → a = 0 + ∀ (k : Int) (a : α), k ≠ 0 → (intCast k) * a = 0 → a = 0 ``` grind can deduce that `x+y+z = 0` -/ diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean index a1295cda6c..4da0951f90 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ prelude import Lean.Meta.Tactic.Grind.Types +import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly namespace Lean.Meta.Grind.Arith.CommRing @@ -77,9 +78,18 @@ Converts the given ring expression into a multivariate polynomial. If the ring has a nonzero characteristic, it is used during normalization. -/ def _root_.Lean.Grind.CommRing.Expr.toPolyM (e : RingExpr) : RingM Poly := do - if let some c ← nonzeroChar? then - return e.toPolyC c - else - return e.toPoly + if let some c ← nonzeroChar? then return e.toPolyC c else return e.toPoly + +def _root_.Lean.Grind.CommRing.Poly.mulConstM (p : Poly) (k : Int) : RingM Poly := + return p.mulConst' k (← nonzeroChar?) + +def _root_.Lean.Grind.CommRing.Poly.mulMonM (p : Poly) (k : Int) (m : Mon) : RingM Poly := + return p.mulMon' k m (← nonzeroChar?) + +def _root_.Lean.Grind.CommRing.Poly.mulM (p₁ p₂ : Poly) : RingM Poly := do + if let some c ← nonzeroChar? then return p₁.mulC p₂ c else return p₁.mul p₂ + +def _root_.Lean.Grind.CommRing.Poly.combineM (p₁ p₂ : Poly) : RingM Poly := + return p₁.combine' p₂ (← nonzeroChar?) end Lean.Meta.Grind.Arith.CommRing