refactor: grind ring as solver extension (#10308)
This PR uses the new solver extension framework to implement `grind ring`.
This commit is contained in:
parent
79051fb5c0
commit
c34ea82bc2
15 changed files with 36 additions and 109 deletions
|
|
@ -21,11 +21,8 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
|||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Inv
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.PP
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean
|
||||
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
builtin_initialize registerTraceClass `grind.ring
|
||||
builtin_initialize registerTraceClass `grind.ring.internalize
|
||||
builtin_initialize registerTraceClass `grind.ring.assert
|
||||
|
|
@ -46,4 +43,12 @@ builtin_initialize registerTraceClass `grind.debug.ring.simpBasis
|
|||
builtin_initialize registerTraceClass `grind.debug.ring.basis
|
||||
builtin_initialize registerTraceClass `grind.debug.ring.rabinowitsch
|
||||
|
||||
end Lean
|
||||
builtin_initialize
|
||||
ringExt.setMethods
|
||||
(internalize := CommRing.internalize)
|
||||
(newEq := CommRing.processNewEq)
|
||||
(newDiseq := CommRing.processNewDiseq)
|
||||
(check := CommRing.check)
|
||||
(checkInv := CommRing.checkInvariants)
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -331,8 +331,7 @@ def addNewDiseq (c : DiseqCnstr) : RingM Unit := do
|
|||
trace[grind.ring.assert.store] "{← c.denoteExpr}"
|
||||
saveDiseq c
|
||||
|
||||
@[export lean_process_ring_eq]
|
||||
def processNewEqImpl (a b : Expr) : GoalM Unit := do
|
||||
def processNewEq (a b : Expr) : GoalM Unit := do
|
||||
if isSameExpr a b then return () -- TODO: check why this is needed
|
||||
if let some ringId ← inSameRing? a b then RingM.run ringId do
|
||||
trace_goal[grind.ring.assert] "{← mkEq a b}"
|
||||
|
|
@ -382,8 +381,7 @@ private def diseqZeroToEq (a b : Expr) : RingM Unit := do
|
|||
trace[grind.debug.ring.rabinowitsch] "{lhs}"
|
||||
pushEq lhs (← getOne) <| mkApp4 (mkConst ``Grind.CommRing.diseq0_to_eq [ring.u]) ring.type fieldInst a (← mkDiseqProof a b)
|
||||
|
||||
@[export lean_process_ring_diseq]
|
||||
def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
|
||||
def processNewDiseq (a b : Expr) : GoalM Unit := do
|
||||
if let some ringId ← inSameRing? a b then RingM.run ringId do
|
||||
trace_goal[grind.ring.assert] "{mkNot (← mkEq a b)}"
|
||||
let some ra ← toRingExpr? a | return ()
|
||||
|
|
|
|||
|
|
@ -8,9 +8,13 @@ prelude
|
|||
public import Lean.Meta.Tactic.Grind.Types
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
builtin_initialize ringExt : SolverExtension State ← registerSolverExtension (return {})
|
||||
|
||||
def get' : GoalM State := do
|
||||
return (← get).arith.ring
|
||||
ringExt.getState
|
||||
|
||||
@[inline] def modify' (f : State → State) : GoalM Unit := do
|
||||
modify fun s => { s with arith.ring := f s.arith.ring }
|
||||
ringExt.modifyState f
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
|||
let some re ← reify? e | return ()
|
||||
trace_goal[grind.ring.internalize] "[{ringId}]: {e}"
|
||||
setTermRingId e
|
||||
markAsCommRingTerm e
|
||||
ringExt.markTerm e
|
||||
modifyRing fun s => { s with
|
||||
denote := s.denote.insert { expr := e } re
|
||||
denoteEntries := s.denoteEntries.push (e, re)
|
||||
|
|
@ -142,7 +142,7 @@ def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
|||
let some re ← sreify? e | return ()
|
||||
trace_goal[grind.ring.internalize] "semiring [{semiringId}]: {e}"
|
||||
setTermSemiringId e
|
||||
markAsCommRingTerm e
|
||||
ringExt.markTerm e
|
||||
modifySemiring fun s => { s with denote := s.denote.insert { expr := e } re }
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.GetSet
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
|
|
@ -49,7 +50,7 @@ private def ppRing? : M (Option MessageData) := do
|
|||
|
||||
def pp? (goal : Goal) : MetaM (Option MessageData) := do
|
||||
let mut msgs := #[]
|
||||
for ring in goal.arith.ring.rings do
|
||||
for ring in (← ringExt.getStateCore goal).rings do
|
||||
let some msg ← ppRing? |>.run' ring | pure ()
|
||||
msgs := msgs.push msg
|
||||
if msgs.isEmpty then
|
||||
|
|
@ -59,4 +60,11 @@ def pp? (goal : Goal) : MetaM (Option MessageData) := do
|
|||
else
|
||||
return some (.trace { cls := `ring } "Rings" msgs)
|
||||
|
||||
def addThresholdMessage (goal : Goal) (c : Grind.Config) (msgs : Array MessageData) : IO (Array MessageData) := do
|
||||
let s ← ringExt.getStateCore goal
|
||||
if s.steps ≥ c.ringSteps then
|
||||
return msgs.push <| .trace { cls := `limit } m!"maximum number of ring steps has been reached, threshold: `(ringSteps := {c.ringSteps})`" #[]
|
||||
else
|
||||
return msgs
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ def mkVar (e : Expr) : RingM Var := do
|
|||
varMap := s.varMap.insert { expr := e } var
|
||||
}
|
||||
setTermRingId e
|
||||
markAsCommRingTerm e
|
||||
ringExt.markTerm e
|
||||
return var
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ def mkSVar (e : Expr) : SemiringM Var := do
|
|||
varMap := s.varMap.insert { expr := e } var
|
||||
}
|
||||
setTermSemiringId e
|
||||
markAsCommRingTerm e
|
||||
ringExt.markTerm e
|
||||
return var
|
||||
|
||||
def _root_.Lean.Grind.Ring.OfSemiring.Expr.denoteAsRingExpr (e : SemiringExpr) : SemiringM Expr := do
|
||||
|
|
|
|||
|
|
@ -4,11 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Offset
|
||||
import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Internalize
|
||||
import Lean.Meta.Tactic.Grind.Arith.Linear.Internalize
|
||||
|
||||
public section
|
||||
|
|
@ -19,7 +17,6 @@ namespace Lean.Meta.Grind.Arith
|
|||
def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
||||
Offset.internalize e parent?
|
||||
Cutsat.internalize e parent?
|
||||
CommRing.internalize e parent?
|
||||
Linear.internalize e parent?
|
||||
|
||||
end Lean.Meta.Grind.Arith
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ public import Lean.Meta.Tactic.Grind.PropagatorAttr
|
|||
public import Lean.Meta.Tactic.Grind.Arith.Offset
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Cutsat.LeCnstr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Cutsat.Search
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.EqCnstr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.IneqCnstr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.Search
|
||||
|
||||
|
|
@ -54,9 +53,8 @@ builtin_grind_propagator propagateLT ↓LT.lt := fun e => do
|
|||
|
||||
def check : GoalM Bool := do
|
||||
let c₁ ← Cutsat.check
|
||||
let c₂ ← CommRing.check
|
||||
let c₃ ← Linear.check
|
||||
if c₁ || c₂ || c₃ then
|
||||
if c₁ || c₃ then
|
||||
processNewFacts
|
||||
return true
|
||||
else
|
||||
|
|
|
|||
|
|
@ -4,22 +4,17 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Offset.Types
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Cutsat.Types
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.Types
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Grind.Arith
|
||||
|
||||
/-- State for the arithmetic procedures. -/
|
||||
structure State where
|
||||
offset : Offset.State := {}
|
||||
cutsat : Cutsat.State := {}
|
||||
ring : CommRing.State := {}
|
||||
linear : Linear.State := {}
|
||||
deriving Inhabited
|
||||
|
||||
|
|
|
|||
|
|
@ -165,31 +165,6 @@ def propagateCutsat : PendingTheoryPropagation → GoalM Unit
|
|||
| .diseqs ps => propagateCutsatDiseqs ps
|
||||
| .none => return ()
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.ring?` fields and detecting what needs to be
|
||||
propagated to the commutative ring module.
|
||||
-/
|
||||
private def checkCommRingEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
|
||||
match lhsRoot.ring? with
|
||||
| some lhsRing =>
|
||||
if let some rhsRing := rhsRoot.ring? then
|
||||
return .eq lhsRing rhsRing
|
||||
else
|
||||
-- We have to retrieve the node because other fields have been updated
|
||||
let rhsRoot ← getENode rhsRoot.self
|
||||
setENode rhsRoot.self { rhsRoot with ring? := lhsRing }
|
||||
return .diseqs (← getParents rhsRoot.self)
|
||||
| none =>
|
||||
if rhsRoot.ring?.isSome then
|
||||
return .diseqs (← getParents lhsRoot.self)
|
||||
else
|
||||
return .none
|
||||
|
||||
def propagateCommRing : PendingTheoryPropagation → GoalM Unit
|
||||
| .eq lhs rhs => Arith.CommRing.processNewEq lhs rhs
|
||||
| .diseqs ps => propagateCommRingDiseqs ps
|
||||
| _ => return ()
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.linarith?` fields and detecting what needs to be
|
||||
propagated to the linarith module.
|
||||
|
|
@ -347,7 +322,6 @@ where
|
|||
propagateBeta lams₂ fns₂
|
||||
let offsetTodo ← checkOffsetEq rhsRoot lhsRoot
|
||||
let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot
|
||||
let ringTodo ← checkCommRingEq rhsRoot lhsRoot
|
||||
let linarithTodo ← checkLinarithEq rhsRoot lhsRoot
|
||||
let todo ← Solvers.mergeTerms rhsRoot lhsRoot
|
||||
resetParentsOf lhsRoot.self
|
||||
|
|
@ -362,7 +336,6 @@ where
|
|||
propagateUnitConstFuns lams₁ lams₂
|
||||
propagateOffset offsetTodo
|
||||
propagateCutsat cutsatTodo
|
||||
propagateCommRing ringTodo
|
||||
propagateLinarith linarithTodo
|
||||
todo.propagate
|
||||
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
|
||||
|
|
|
|||
|
|
@ -428,6 +428,7 @@ where
|
|||
-- We do not want to internalize the components of a literal value.
|
||||
mkENode e generation
|
||||
internalizeTheories e parent?
|
||||
Solvers.internalize e parent?
|
||||
else if e.isAppOfArity ``Grind.MatchCond 1 then
|
||||
internalizeMatchCond e generation
|
||||
else e.withApp fun f args => do
|
||||
|
|
|
|||
|
|
@ -178,8 +178,7 @@ private def ppThresholds (c : Grind.Config) : M Unit := do
|
|||
msgs := msgs.push <| .trace { cls := `limit } m!"maximum number of case-splits has been reached, threshold: `(splits := {c.splits})`" #[]
|
||||
if maxGen ≥ c.gen then
|
||||
msgs := msgs.push <| .trace { cls := `limit } m!"maximum term generation has been reached, threshold: `(gen := {c.gen})`" #[]
|
||||
if goal.arith.ring.steps ≥ c.ringSteps then
|
||||
msgs := msgs.push <| .trace { cls := `limit } m!"maximum number of ring steps has been reached, threshold: `(ringSteps := {c.ringSteps})`" #[]
|
||||
msgs ← Arith.CommRing.addThresholdMessage goal c msgs
|
||||
unless msgs.isEmpty do
|
||||
pushMsg <| .trace { cls := `limits } "Thresholds reached" msgs
|
||||
|
||||
|
|
|
|||
|
|
@ -184,7 +184,6 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
|
|||
if α.isConstOf ``Bool then
|
||||
propagateBoolDiseq e lhs rhs
|
||||
propagateCutsatDiseq lhs rhs
|
||||
propagateCommRingDiseq lhs rhs
|
||||
propagateLinarithDiseq lhs rhs
|
||||
Solvers.propagateDiseqs lhs rhs
|
||||
let thms ← getExtTheorems α
|
||||
|
|
|
|||
|
|
@ -452,11 +452,6 @@ structure ENode where
|
|||
-/
|
||||
cutsat? : Option Expr := none
|
||||
/--
|
||||
The `ring?` field is used to propagate equalities from the `grind` congruence closure module
|
||||
to the comm ring module. Its implementation is similar to the `offset?` field.
|
||||
-/
|
||||
ring? : Option Expr := none
|
||||
/--
|
||||
The `linarith?` field is used to propagate equalities from the `grind` congruence closure module
|
||||
to the linarith module. Its implementation is similar to the `offset?` field.
|
||||
-/
|
||||
|
|
@ -1206,53 +1201,6 @@ def markAsCutsatTerm (e : Expr) : GoalM Unit := do
|
|||
setENode root.self { root with cutsat? := some e }
|
||||
propagateCutsatDiseqs (← getParents root.self)
|
||||
|
||||
/--
|
||||
Notifies the comm ring module that `a = b` where
|
||||
`a` and `b` are terms that have been internalized by this module.
|
||||
-/
|
||||
@[extern "lean_process_ring_eq"] -- forward definition
|
||||
opaque Arith.CommRing.processNewEq (a b : Expr) : GoalM Unit
|
||||
|
||||
/--
|
||||
Notifies the comm ring module that `a ≠ b` where
|
||||
`a` and `b` are terms that have been internalized by this module.
|
||||
-/
|
||||
@[extern "lean_process_ring_diseq"] -- forward definition
|
||||
opaque Arith.CommRing.processNewDiseq (a b : Expr) : GoalM Unit
|
||||
|
||||
/--
|
||||
Given `lhs` and `rhs` that are known to be disequal, checks whether
|
||||
`lhs` and `rhs` have ring terms `e₁` and `e₂` attached to them,
|
||||
and invokes process `Arith.CommRing.processNewDiseq e₁ e₂`
|
||||
-/
|
||||
def propagateCommRingDiseq (lhs rhs : Expr) : GoalM Unit := do
|
||||
let some lhs ← get? lhs | return ()
|
||||
let some rhs ← get? rhs | return ()
|
||||
Arith.CommRing.processNewDiseq lhs rhs
|
||||
where
|
||||
get? (a : Expr) : GoalM (Option Expr) := do
|
||||
return (← getRootENode a).ring?
|
||||
|
||||
/--
|
||||
Traverses disequalities in `parents`, and propagate the ones relevant to the
|
||||
comm ring module.
|
||||
-/
|
||||
def propagateCommRingDiseqs (parents : ParentSet) : GoalM Unit := do
|
||||
forEachDiseq parents propagateCommRingDiseq
|
||||
|
||||
/--
|
||||
Marks `e` as a term of interest to the ring module.
|
||||
If the root of `e`s equivalence class has already a term of interest,
|
||||
a new equality is propagated to the ring module.
|
||||
-/
|
||||
def markAsCommRingTerm (e : Expr) : GoalM Unit := do
|
||||
let root ← getRootENode e
|
||||
if let some e' := root.ring? then
|
||||
Arith.CommRing.processNewEq e e'
|
||||
else
|
||||
setENode root.self { root with ring? := some e }
|
||||
propagateCommRingDiseqs (← getParents root.self)
|
||||
|
||||
/--
|
||||
Notifies the linarith module that `a = b` where
|
||||
`a` and `b` are terms that have been internalized by this module.
|
||||
|
|
@ -1700,6 +1648,8 @@ def Solvers.check : GoalM Bool := do
|
|||
for ext in (← solverExtensionsRef.get) do
|
||||
if (← ext.check) then
|
||||
result := true
|
||||
if result then
|
||||
processNewFacts
|
||||
return result
|
||||
|
||||
/-- Invokes model-based theory combination extensions in all registered solvers. -/
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue