feat: grind solver extensions (part 2) (#10294)
This PR completes the `grind` solver extension design and ports the `grind ac` solver to the new framework. Future PRs will document the API and port the remaining solvers. An additional benefit of the new design is faster build times.
This commit is contained in:
parent
6a8d7cc17c
commit
be1e090833
17 changed files with 220 additions and 136 deletions
|
|
@ -16,8 +16,9 @@ public import Lean.Meta.Tactic.Grind.AC.DenoteExpr
|
|||
public import Lean.Meta.Tactic.Grind.AC.ToExpr
|
||||
public import Lean.Meta.Tactic.Grind.AC.VarRename
|
||||
public import Lean.Meta.Tactic.Grind.AC.PP
|
||||
public import Lean.Meta.Tactic.Grind.AC.Inv
|
||||
public section
|
||||
namespace Lean
|
||||
namespace Lean.Meta.Grind.AC
|
||||
builtin_initialize registerTraceClass `grind.ac
|
||||
builtin_initialize registerTraceClass `grind.ac.assert
|
||||
builtin_initialize registerTraceClass `grind.ac.internalize
|
||||
|
|
@ -29,4 +30,13 @@ builtin_initialize registerTraceClass `grind.debug.ac.check
|
|||
builtin_initialize registerTraceClass `grind.debug.ac.queue
|
||||
builtin_initialize registerTraceClass `grind.debug.ac.superpose
|
||||
builtin_initialize registerTraceClass `grind.debug.ac.eq
|
||||
end Lean
|
||||
|
||||
builtin_initialize
|
||||
acExt.setMethods
|
||||
(internalize := AC.internalize)
|
||||
(newEq := AC.processNewEq)
|
||||
(newDiseq := AC.processNewDiseq)
|
||||
(check := AC.check)
|
||||
(checkInv := AC.checkInvariants)
|
||||
|
||||
end Lean.Meta.Grind.AC
|
||||
|
|
|
|||
|
|
@ -360,8 +360,7 @@ private def EqCnstr.assert (c : EqCnstr) : ACM Unit := do
|
|||
else
|
||||
c.addToQueue
|
||||
|
||||
@[export lean_process_ac_eq]
|
||||
def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do
|
||||
def processNewEq(a b : Expr) : GoalM Unit := withExprs a b do
|
||||
let ea ← asACExpr a
|
||||
let lhs ← norm ea
|
||||
let eb ← asACExpr b
|
||||
|
|
@ -369,8 +368,7 @@ def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do
|
|||
let c ← mkEqCnstr lhs rhs (.core a b ea eb)
|
||||
c.assert
|
||||
|
||||
@[export lean_process_ac_diseq]
|
||||
def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do
|
||||
def processNewDiseq (a b : Expr) : GoalM Unit := withExprs a b do
|
||||
let ea ← asACExpr a
|
||||
let lhs ← norm ea
|
||||
let eb ← asACExpr b
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@ partial def reify (e : Expr) : ACM Grind.AC.Expr := do
|
|||
else
|
||||
return .var (← mkVar e)
|
||||
|
||||
@[export lean_grind_ac_internalize]
|
||||
def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
||||
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
||||
unless (← getConfig).ac do return ()
|
||||
unless e.isApp && e.appFn!.isApp do return ()
|
||||
let op := e.appFn!.appFn!
|
||||
|
|
@ -39,6 +38,6 @@ def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
|||
}
|
||||
trace[grind.ac.internalize] "[{id}] {← e'.denoteExpr}"
|
||||
addTermOpId e
|
||||
markAsACTerm e
|
||||
acExt.markTerm e
|
||||
|
||||
end Lean.Meta.Grind.AC
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ private def ppStruct? : M (Option MessageData) := do
|
|||
|
||||
def pp? (goal : Goal) : MetaM (Option MessageData) := do
|
||||
let mut msgs := #[]
|
||||
for struct in goal.ac.structs do
|
||||
for struct in (← acExt.getStateCore goal).structs do
|
||||
let some msg ← ppStruct? |>.run' struct | pure ()
|
||||
msgs := msgs.push msg
|
||||
if msgs.isEmpty then
|
||||
|
|
|
|||
|
|
@ -8,9 +8,7 @@ prelude
|
|||
public import Init.Core
|
||||
public import Init.Grind.AC
|
||||
public import Std.Data.HashMap
|
||||
public import Lean.Expr
|
||||
public import Lean.Data.PersistentArray
|
||||
public import Lean.Meta.Tactic.Grind.ExprPtr
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.AC.Seq
|
||||
public section
|
||||
namespace Lean.Meta.Grind.AC
|
||||
|
|
@ -98,7 +96,7 @@ structure Struct where
|
|||
/-- Mapping from Lean expressions to their representations as `AC.Expr` -/
|
||||
denote : PHashMap ExprPtr AC.Expr := {}
|
||||
/-- `denoteEntries` is `denote` as a `PArray` for deterministic traversal. -/
|
||||
denoteEntries : PArray (Expr × AC.Expr) := {}
|
||||
denoteEntries : PArray (Expr × AC.Expr) := {}
|
||||
/-- Equations to process. -/
|
||||
queue : Queue := {}
|
||||
/-- Processed equations. -/
|
||||
|
|
@ -130,4 +128,6 @@ structure State where
|
|||
steps := 0
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize acExt : SolverExtension State ← registerSolverExtension (return {})
|
||||
|
||||
end Lean.Meta.Grind.AC
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
public import Lean.Meta.Tactic.Grind.AC.Types
|
||||
public import Lean.Meta.Tactic.Grind.ProveEq
|
||||
public import Lean.Meta.Tactic.Grind.SynthInstance
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
|
||||
|
|
@ -15,10 +15,10 @@ namespace Lean.Meta.Grind.AC
|
|||
open Lean.Grind
|
||||
|
||||
def get' : GoalM State := do
|
||||
return (← get).ac
|
||||
acExt.getState
|
||||
|
||||
@[inline] def modify' (f : State → State) : GoalM Unit := do
|
||||
modify fun s => { s with ac := f s.ac }
|
||||
acExt.modifyState f
|
||||
|
||||
def checkMaxSteps : GoalM Bool := do
|
||||
return (← get').steps >= (← getConfig).acSteps
|
||||
|
|
@ -115,7 +115,7 @@ def mkVar (e : Expr) : ACM AC.Var := do
|
|||
varMap := s.varMap.insert { expr := e } var
|
||||
}
|
||||
addTermOpId e
|
||||
markAsACTerm e
|
||||
acExt.markTerm e
|
||||
return var
|
||||
|
||||
def getOpId? (op : Expr) : GoalM (Option Nat) := do
|
||||
|
|
|
|||
|
|
@ -7,12 +7,11 @@ module
|
|||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Arith.Main
|
||||
import Lean.Meta.Tactic.Grind.AC.Eq
|
||||
namespace Lean.Meta.Grind
|
||||
/--
|
||||
Checks whether satellite solvers can make progress (e.g., detect unsatisfiability, propagate equations, etc)
|
||||
-/
|
||||
public def check : GoalM Bool := do
|
||||
Arith.check <||> AC.check
|
||||
Arith.check
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -215,31 +215,6 @@ def propagateLinarith : PendingTheoryPropagation → GoalM Unit
|
|||
| .diseqs ps => propagateLinarithDiseqs ps
|
||||
| _ => return ()
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.ac?` fields and detecting what needs to be
|
||||
propagated to the ac module.
|
||||
-/
|
||||
private def checkACEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
|
||||
match lhsRoot.ac? with
|
||||
| some lhs =>
|
||||
if let some rhs := rhsRoot.ac? then
|
||||
return .eq lhs rhs
|
||||
else
|
||||
-- We have to retrieve the node because other fields have been updated
|
||||
let rhsRoot ← getENode rhsRoot.self
|
||||
setENode rhsRoot.self { rhsRoot with ac? := lhs }
|
||||
return .diseqs (← getParents rhsRoot.self)
|
||||
| none =>
|
||||
if rhsRoot.ac?.isSome then
|
||||
return .diseqs (← getParents lhsRoot.self)
|
||||
else
|
||||
return .none
|
||||
|
||||
def propagateAC : PendingTheoryPropagation → GoalM Unit
|
||||
| .eq lhs rhs => AC.processNewEq lhs rhs
|
||||
| .diseqs ps => propagateACDiseqs ps
|
||||
| _ => return ()
|
||||
|
||||
/--
|
||||
Tries to apply beta-reduction using the parent applications of the functions in `fns` with
|
||||
the lambda expressions in `lams`.
|
||||
|
|
@ -374,7 +349,7 @@ where
|
|||
let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot
|
||||
let ringTodo ← checkCommRingEq rhsRoot lhsRoot
|
||||
let linarithTodo ← checkLinarithEq rhsRoot lhsRoot
|
||||
let ACTodo ← checkACEq rhsRoot lhsRoot
|
||||
let todo ← Solvers.mergeTerms rhsRoot lhsRoot
|
||||
resetParentsOf lhsRoot.self
|
||||
copyParentsTo parents rhsNode.root
|
||||
unless (← isInconsistent) do
|
||||
|
|
@ -389,7 +364,7 @@ where
|
|||
propagateCutsat cutsatTodo
|
||||
propagateCommRing ringTodo
|
||||
propagateLinarith linarithTodo
|
||||
propagateAC ACTodo
|
||||
todo.propagate
|
||||
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
|
||||
let isFalseRoot ← isFalseExpr rootNew
|
||||
traverseEqc lhs fun n => do
|
||||
|
|
|
|||
|
|
@ -21,8 +21,6 @@ import Lean.Meta.Tactic.Grind.MarkNestedSubsingletons
|
|||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
@[extern "lean_grind_ac_internalize"] -- forward definition
|
||||
opaque AC.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit
|
||||
@[extern "lean_grind_arith_internalize"] -- forward definition
|
||||
opaque Arith.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit
|
||||
|
||||
|
|
@ -366,7 +364,6 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do
|
|||
|
||||
private def internalizeTheories (e : Expr) (parent? : Option Expr := none) : GoalM Unit := do
|
||||
Arith.internalize e parent?
|
||||
AC.internalize e parent?
|
||||
|
||||
@[export lean_grind_internalize]
|
||||
private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do
|
||||
|
|
@ -381,6 +378,7 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt
|
|||
Otherwise, it will not be able to propagate that `a + 1 = 1` when `a = 0`
|
||||
-/
|
||||
internalizeTheories e parent?
|
||||
Solvers.internalize e parent?
|
||||
else
|
||||
go
|
||||
propagateEtaStruct e generation
|
||||
|
|
@ -467,6 +465,7 @@ where
|
|||
registerParent e arg
|
||||
addCongrTable e
|
||||
internalizeTheories e parent?
|
||||
Solvers.internalize e parent?
|
||||
propagateUp e
|
||||
propagateBetaForNewApp e
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ public import Lean.Meta.Tactic.Grind.Types
|
|||
import Lean.Meta.Tactic.Grind.Proof
|
||||
import Lean.Meta.Tactic.Grind.MatchCond
|
||||
import Lean.Meta.Tactic.Grind.Arith.Inv
|
||||
import Lean.Meta.Tactic.Grind.AC.Inv
|
||||
namespace Lean.Meta.Grind
|
||||
/-!
|
||||
Debugging support code for checking basic invariants.
|
||||
|
|
@ -126,7 +125,7 @@ public def checkInvariants (expensive := false) : GoalM Unit := do
|
|||
if expensive then
|
||||
checkPtrEqImpliesStructEq
|
||||
Arith.checkInvariants
|
||||
AC.checkInvariants
|
||||
Solvers.checkInvariants
|
||||
if expensive && grind.debug.proofs.get (← getOptions) then
|
||||
checkProofs
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ private partial def solve (generation : Nat) : SearchM Bool := withIncRecDepth d
|
|||
return false -- `splitNext` should have been configured to not create choice points
|
||||
if (← getGoal).inconsistent then
|
||||
return true
|
||||
if (← intros' generation <||> assertAll <||> check <||> splitNext <||> ematch) then
|
||||
if (← intros' generation <||> assertAll <||> check <||> Solvers.check <||> splitNext <||> ematch) then
|
||||
solve generation
|
||||
else
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -109,7 +109,8 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
|
|||
let thmMap := params.ematch
|
||||
let casesTypes := params.casesTypes
|
||||
let clean ← mkCleanState mvarId params
|
||||
GoalM.run' { mvarId, ematch.thmMap := thmMap, split.casesTypes := casesTypes, clean } do
|
||||
let sstates ← Solvers.mkInitialStates
|
||||
GoalM.run' { mvarId, ematch.thmMap := thmMap, split.casesTypes := casesTypes, clean, sstates } do
|
||||
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Grind.Util
|
||||
public import Init.Grind.PP
|
||||
|
|
@ -14,7 +13,6 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.PP
|
|||
public import Lean.Meta.Tactic.Grind.Arith.Linear.PP
|
||||
public import Lean.Meta.Tactic.Grind.AC.PP
|
||||
import Lean.PrettyPrinter
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
|
|||
propagateCutsatDiseq lhs rhs
|
||||
propagateCommRingDiseq lhs rhs
|
||||
propagateLinarithDiseq lhs rhs
|
||||
propagateACDiseq lhs rhs
|
||||
Solvers.propagateDiseqs lhs rhs
|
||||
let thms ← getExtTheorems α
|
||||
if !thms.isEmpty then
|
||||
/-
|
||||
|
|
|
|||
|
|
@ -51,8 +51,8 @@ where
|
|||
intros gen
|
||||
else
|
||||
break
|
||||
if (← assertAll <||> check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc
|
||||
<||> Arith.Linear.mbtc <||> tryFallback) then
|
||||
if (← assertAll <||> check <||> Solvers.check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc
|
||||
<||> Arith.Linear.mbtc <||> Solvers.mbtc <||> tryFallback) then
|
||||
continue
|
||||
return some (← getGoal) -- failed
|
||||
return none -- solved
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ public import Lean.Meta.Tactic.Grind.AlphaShareCommon
|
|||
public import Lean.Meta.Tactic.Grind.Attr
|
||||
public import Lean.Meta.Tactic.Grind.ExtAttr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Types
|
||||
public import Lean.Meta.Tactic.Grind.AC.Types
|
||||
public import Lean.Meta.Tactic.Grind.EMatchTheorem
|
||||
meta import Lean.Parser.Do
|
||||
import Lean.Meta.Match.MatchEqsExt
|
||||
|
|
@ -379,6 +378,23 @@ private def expandReportIssueMacro (s : Syntax) : MacroM (TSyntax `doElem) := do
|
|||
macro "reportIssue!" s:(interpolatedStr(term) <|> term) : doElem => do
|
||||
expandReportIssueMacro s.raw
|
||||
|
||||
/--
|
||||
Each E-node may have "solver terms" attached to them.
|
||||
Each term is an element of the equivalence class that the
|
||||
solver cares about. Each solver is responsible for marking the terms they care about.
|
||||
The `grind` core propagates equalities and disequalities to the theory solvers
|
||||
using these "marked" terms. The root of the equivalence class
|
||||
contains a list of representatives sorted by solver id. Note that many E-nodes
|
||||
do not have any solver terms attached to them.
|
||||
|
||||
"Solver terms" are referenced as "theory variables" in the SMT literature.
|
||||
The SMT solver Z3 uses a similar representation.
|
||||
-/
|
||||
inductive SolverTerms where
|
||||
| nil
|
||||
| next (solverId : Nat) (e : Expr) (rest : SolverTerms)
|
||||
deriving Inhabited, Repr
|
||||
|
||||
/--
|
||||
Stores information for a node in the E-graph.
|
||||
Each internalized expression `e` has an `ENode` associated with it.
|
||||
|
|
@ -445,13 +461,8 @@ structure ENode where
|
|||
to the linarith module. Its implementation is similar to the `offset?` field.
|
||||
-/
|
||||
linarith? : Option Expr := none
|
||||
/--
|
||||
The `ac?` field is used to propagate equalities from the `grind` congruence closure module
|
||||
to the ac module. Its implementation is similar to the `offset?` field.
|
||||
-/
|
||||
ac? : Option Expr := none
|
||||
-- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, linarith, and ac.
|
||||
-- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3.
|
||||
/-- Solver terms attached to this E-node. -/
|
||||
sTerms : SolverTerms := .nil
|
||||
deriving Inhabited, Repr
|
||||
|
||||
def ENode.isRoot (n : ENode) :=
|
||||
|
|
@ -773,8 +784,6 @@ structure Goal where
|
|||
split : Split.State := {}
|
||||
/-- State of arithmetic procedures. -/
|
||||
arith : Arith.State := {}
|
||||
/-- State of the ac solver. -/
|
||||
ac : AC.State := {}
|
||||
/-- State of the clean name generator. -/
|
||||
clean : Clean.State := {}
|
||||
/-- Solver states. -/
|
||||
|
|
@ -1171,14 +1180,10 @@ and invokes process `Arith.Cutsat.processNewDiseq e₁ e₂`
|
|||
def propagateCutsatDiseq (lhs rhs : Expr) : GoalM Unit := do
|
||||
let some lhs ← get? lhs | return ()
|
||||
let some rhs ← get? rhs | return ()
|
||||
-- Recall that core can take care of disequalities of the form `1≠2`.
|
||||
unless isNum lhs && isNum rhs do
|
||||
Arith.Cutsat.processNewDiseq lhs rhs
|
||||
Arith.Cutsat.processNewDiseq lhs rhs
|
||||
where
|
||||
get? (a : Expr) : GoalM (Option Expr) := do
|
||||
let root ← getRootENode a
|
||||
if isNum root.self then
|
||||
return some root.self
|
||||
return root.cutsat?
|
||||
|
||||
/--
|
||||
|
|
@ -1295,53 +1300,6 @@ def markAsLinarithTerm (e : Expr) : GoalM Unit := do
|
|||
setENode root.self { root with linarith? := some e }
|
||||
propagateLinarithDiseqs (← getParents root.self)
|
||||
|
||||
/--
|
||||
Notifies the ac module that `a = b` where
|
||||
`a` and `b` are terms that have been internalized by this module.
|
||||
-/
|
||||
@[extern "lean_process_ac_eq"] -- forward definition
|
||||
opaque AC.processNewEq (a b : Expr) : GoalM Unit
|
||||
|
||||
/--
|
||||
Notifies the ac module that `a ≠ b` where
|
||||
`a` and `b` are terms that have been internalized by this module.
|
||||
-/
|
||||
@[extern "lean_process_ac_diseq"] -- forward definition
|
||||
opaque AC.processNewDiseq (a b : Expr) : GoalM Unit
|
||||
|
||||
/--
|
||||
Given `lhs` and `rhs` that are known to be disequal, checks whether
|
||||
`lhs` and `rhs` have ac terms `e₁` and `e₂` attached to them,
|
||||
and invokes process `AC.processNewDiseq e₁ e₂`
|
||||
-/
|
||||
def propagateACDiseq (lhs rhs : Expr) : GoalM Unit := do
|
||||
let some lhs ← get? lhs | return ()
|
||||
let some rhs ← get? rhs | return ()
|
||||
AC.processNewDiseq lhs rhs
|
||||
where
|
||||
get? (a : Expr) : GoalM (Option Expr) := do
|
||||
return (← getRootENode a).ac?
|
||||
|
||||
/--
|
||||
Traverses disequalities in `parents`, and propagate the ones relevant to the
|
||||
ac module.
|
||||
-/
|
||||
def propagateACDiseqs (parents : ParentSet) : GoalM Unit := do
|
||||
forEachDiseq parents propagateACDiseq
|
||||
|
||||
/--
|
||||
Marks `e` as a term of interest to the ac module.
|
||||
If the root of `e`s equivalence class has already a term of interest,
|
||||
a new equality is propagated to the ac module.
|
||||
-/
|
||||
def markAsACTerm (e : Expr) : GoalM Unit := do
|
||||
let root ← getRootENode e
|
||||
if let some e' := root.ac? then
|
||||
AC.processNewEq e e'
|
||||
else
|
||||
setENode root.self { root with ac? := some e }
|
||||
propagateACDiseqs (← getParents root.self)
|
||||
|
||||
/-- Returns `true` is `e` is the root of its congruence class. -/
|
||||
def isCongrRoot (e : Expr) : GoalM Bool := do
|
||||
return (← getENode e).isCongrRoot
|
||||
|
|
@ -1659,29 +1617,52 @@ structure SolverExtension (σ : Type) where private mk ::
|
|||
checkInv : GoalM Unit
|
||||
deriving Inhabited
|
||||
|
||||
private builtin_initialize solverExtensionsRef : IO.Ref (Array (SolverExtension EnvExtensionState)) ← IO.mkRef #[]
|
||||
private builtin_initialize solverExtensionsRef : IO.Ref (Array (SolverExtension SolverExtensionState)) ← IO.mkRef #[]
|
||||
|
||||
/--
|
||||
Registers a new solver extension for `grind`.
|
||||
Solver extensions can only be registered during initialization.
|
||||
Reason: We do not use any synchronization primitive to access `solverExtensionsRef`.
|
||||
-/
|
||||
def registerSolverExtension {σ : Type}
|
||||
(mkInitial : IO σ)
|
||||
(internalize : Expr → (parent? : Option Expr) → GoalM Unit)
|
||||
(newEq : Expr → Expr → GoalM Unit)
|
||||
(newDiseq : Expr → Expr → GoalM Unit := fun _ _ => return ())
|
||||
(mbtc : GoalM Bool := return false)
|
||||
(check : GoalM Bool := return false)
|
||||
(checkInv : GoalM Unit := return ()) : IO (SolverExtension σ) := do
|
||||
def registerSolverExtension {σ : Type} (mkInitial : IO σ) : IO (SolverExtension σ) := do
|
||||
unless (← initializing) do
|
||||
throw (IO.userError "failed to register `grind` solver, extensions can only be registered during initialization")
|
||||
let exts ← solverExtensionsRef.get
|
||||
let id := exts.size
|
||||
let ext : SolverExtension σ := { id, mkInitial, internalize, newEq, newDiseq, check, checkInv, mbtc }
|
||||
let ext : SolverExtension σ := {
|
||||
id, mkInitial
|
||||
internalize := fun _ _ => return ()
|
||||
newEq := fun _ _ => return ()
|
||||
newDiseq := fun _ _ => return ()
|
||||
check := fun _ _ => return false
|
||||
checkInv := fun _ _ => return ()
|
||||
mbtc := fun _ _ => return false
|
||||
}
|
||||
solverExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext)
|
||||
return ext
|
||||
|
||||
def mkInitialExtStates : IO (Array EnvExtensionState) := do
|
||||
/--
|
||||
Sets methods/handlers for solver extension `ext`.
|
||||
Solver extension methods can only be registered during initialization.
|
||||
Reason: We do not use any synchronization primitive to access `solverExtensionsRef`.
|
||||
-/
|
||||
def SolverExtension.setMethods (ext : SolverExtension σ)
|
||||
(internalize : Expr → (parent? : Option Expr) → GoalM Unit := fun _ _ => return ())
|
||||
(newEq : Expr → Expr → GoalM Unit := fun _ _ => return ())
|
||||
(newDiseq : Expr → Expr → GoalM Unit := fun _ _ => return ())
|
||||
(mbtc : GoalM Bool := return false)
|
||||
(check : GoalM Bool := return false)
|
||||
(checkInv : GoalM Unit := return ()) : IO Unit := do
|
||||
unless (← initializing) do
|
||||
throw (IO.userError "failed to register `grind` solver, extensions can only be registered during initialization")
|
||||
unless ext.id < (← solverExtensionsRef.get).size do
|
||||
throw (IO.userError "failed to register `grind` solver methods, invalid solver id")
|
||||
solverExtensionsRef.modify fun exts => exts.modify ext.id fun s => { s with
|
||||
internalize, newEq, newDiseq, mbtc, check, checkInv
|
||||
}
|
||||
|
||||
/-- Returns initial state for registered solvers. -/
|
||||
def Solvers.mkInitialStates : IO (Array SolverExtensionState) := do
|
||||
let exts ← solverExtensionsRef.get
|
||||
exts.mapM fun ext => ext.mkInitial
|
||||
|
||||
|
|
@ -1696,11 +1677,14 @@ private unsafe def SolverExtension.modifyStateImpl (ext : SolverExtension σ) (f
|
|||
@[implemented_by SolverExtension.modifyStateImpl]
|
||||
opaque SolverExtension.modifyState (ext : SolverExtension σ) (f : σ → σ) : GoalM Unit
|
||||
|
||||
private unsafe def SolverExtension.getStateImpl (ext : SolverExtension σ) : GoalM σ := do
|
||||
return unsafeCast (← get).sstates[ext.id]!
|
||||
private unsafe def SolverExtension.getStateCoreImpl (ext : SolverExtension σ) (goal : Goal) : IO σ :=
|
||||
return unsafeCast goal.sstates[ext.id]!
|
||||
|
||||
@[implemented_by SolverExtension.getStateImpl]
|
||||
opaque SolverExtension.getState (ext : SolverExtension σ) : GoalM σ
|
||||
@[implemented_by SolverExtension.getStateCoreImpl]
|
||||
opaque SolverExtension.getStateCore (ext : SolverExtension σ) (goal : Goal) : IO σ
|
||||
|
||||
def SolverExtension.getState (ext : SolverExtension σ) : GoalM σ := do
|
||||
ext.getStateCore (← get)
|
||||
|
||||
/-- Internalizes given expression in all registered solvers. -/
|
||||
def Solvers.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
||||
|
|
@ -1726,4 +1710,125 @@ def Solvers.mbtc : GoalM Bool := do
|
|||
result := true
|
||||
return result
|
||||
|
||||
/--
|
||||
Given a new disequality `lhs ≠ rhs`, propagates it to relevant theories.
|
||||
-/
|
||||
def Solvers.propagateDiseqs (lhs rhs : Expr) : GoalM Unit := do
|
||||
go (← getRootENode lhs).sTerms (← getRootENode rhs).sTerms
|
||||
where
|
||||
go (lhsTerms rhsTerms : SolverTerms) : GoalM Unit := do
|
||||
match lhsTerms, rhsTerms with
|
||||
| .nil, _ => return ()
|
||||
| _, .nil => return ()
|
||||
| .next id₁ lhs lhsTerms, .next id₂ rhs rhsTerms =>
|
||||
if id₁ == id₂ then
|
||||
(← solverExtensionsRef.get)[id₁]!.newDiseq lhs rhs
|
||||
go lhsTerms rhsTerms
|
||||
else if id₁ < id₂ then
|
||||
go lhsTerms (.next id₂ rhs rhsTerms)
|
||||
else
|
||||
go (.next id₁ lhs lhsTerms) rhsTerms
|
||||
|
||||
private def propagateDiseqOf (id : Nat) (lhs rhs : Expr) : GoalM Unit := do
|
||||
visitLhs (← getRootENode lhs).sTerms
|
||||
where
|
||||
visitLhs (sTerms : SolverTerms) : GoalM Unit := do
|
||||
match sTerms with
|
||||
| .nil => return ()
|
||||
| .next id' e sTerms =>
|
||||
if id == id' then
|
||||
visitRhs e (← getRootENode rhs).sTerms
|
||||
else if id < id' then
|
||||
return ()
|
||||
else
|
||||
visitLhs sTerms
|
||||
|
||||
visitRhs (lhsTerm : Expr) (sTerms : SolverTerms) : GoalM Unit := do
|
||||
match sTerms with
|
||||
| .nil => return ()
|
||||
| .next id' e sTerms =>
|
||||
if id == id' then
|
||||
let rhsTerm := e
|
||||
(← solverExtensionsRef.get)[id]!.newDiseq lhsTerm rhsTerm
|
||||
else if id < id' then
|
||||
return ()
|
||||
else
|
||||
visitRhs lhsTerm sTerms
|
||||
|
||||
def isSameSolverTerms (a b : SolverTerms) : Bool :=
|
||||
unsafe ptrEq a b
|
||||
|
||||
def SolverExtension.markTerm (ext : SolverExtension σ) (e : Expr) : GoalM Unit := do
|
||||
let root ← getRootENode e
|
||||
let id := ext.id
|
||||
let rec go (sTerms : SolverTerms) : GoalM SolverTerms := do
|
||||
match sTerms with
|
||||
| .nil => return .next id e .nil
|
||||
| .next id' e' sTerms' =>
|
||||
if id == id' then
|
||||
(← solverExtensionsRef.get)[id]!.newEq e e'
|
||||
return sTerms
|
||||
else if id < id' then
|
||||
return .next id e sTerms
|
||||
else
|
||||
let sTermsNew ← go sTerms'
|
||||
if isSameSolverTerms sTermsNew sTerms' then
|
||||
return sTerms
|
||||
else
|
||||
return .next id' e' sTermsNew
|
||||
let sTermsNew ← go root.sTerms
|
||||
unless isSameSolverTerms sTermsNew root.sTerms do
|
||||
setENode root.self { root with sTerms := sTermsNew }
|
||||
forEachDiseq (← getParents root.self) (propagateDiseqOf id)
|
||||
|
||||
private inductive PendingSolverPropagationsData where
|
||||
| nil
|
||||
| eq (solverId : Nat) (lhs rhs : Expr) (rest : PendingSolverPropagationsData)
|
||||
| diseqs (solverId : Nat) (ps : ParentSet) (rest : PendingSolverPropagationsData)
|
||||
|
||||
structure PendingSolverPropagations where private mk ::
|
||||
private data : PendingSolverPropagationsData
|
||||
|
||||
def Solvers.mergeTerms (rhsRoot lhsRoot : ENode) : GoalM PendingSolverPropagations := do
|
||||
let (sTerms, data) ← go rhsRoot.sTerms lhsRoot.sTerms
|
||||
unless sTerms matches .nil do
|
||||
-- We have to retrieve the node because other fields have been updated
|
||||
let rhsRoot ← getENode rhsRoot.self
|
||||
setENode rhsRoot.self { rhsRoot with sTerms }
|
||||
return { data }
|
||||
where
|
||||
toPendingDiseqs (sTerms : SolverTerms) (ps : ParentSet) : PendingSolverPropagationsData :=
|
||||
match sTerms with
|
||||
| .nil => .nil
|
||||
| .next id _ sTerms => .diseqs id ps (toPendingDiseqs sTerms ps)
|
||||
|
||||
go (rhsTerms : SolverTerms) (lhsTerms : SolverTerms) : GoalM (SolverTerms × PendingSolverPropagationsData) := do
|
||||
match rhsTerms, lhsTerms with
|
||||
| .nil, .nil => return (.nil, .nil)
|
||||
| .nil, .next .. => return (lhsTerms, toPendingDiseqs lhsTerms (← getParents rhsRoot.self))
|
||||
| .next .., .nil => return (rhsTerms, toPendingDiseqs rhsTerms (← getParents lhsRoot.self))
|
||||
| .next id₁ rhs rhsTerms, .next id₂ lhs lhsTerms =>
|
||||
if id₁ == id₂ then
|
||||
let (s, p) ← go rhsTerms lhsTerms
|
||||
return (.next id₁ rhs s, .eq id₁ lhs rhs p)
|
||||
else if id₁ < id₂ then
|
||||
let (s, p) ← go rhsTerms (.next id₂ lhs lhsTerms)
|
||||
return (.next id₁ rhs s, .diseqs id₁ (← getParents lhsRoot.self) p)
|
||||
else
|
||||
let (s, p) ← go (.next id₁ rhs rhsTerms) lhsTerms
|
||||
return (.next id₂ lhs s, .diseqs id₂ (← getParents rhsRoot.self) p)
|
||||
|
||||
def PendingSolverPropagations.propagate (p : PendingSolverPropagations) : GoalM Unit := do
|
||||
go p.data
|
||||
where
|
||||
go (p : PendingSolverPropagationsData) : GoalM Unit := do
|
||||
match p with
|
||||
| .nil => return ()
|
||||
| .eq solverId lhs rhs rest =>
|
||||
(← solverExtensionsRef.get)[solverId]!.newEq lhs rhs
|
||||
go rest
|
||||
| .diseqs solverId parentSet rest =>
|
||||
forEachDiseq parentSet (propagateDiseqOf solverId)
|
||||
go rest
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
// update me!
|
||||
#include "util/options.h"
|
||||
|
||||
namespace lean {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue