feat: grind solver extensions (part 2) (#10294)

This PR completes the `grind` solver extension design and ports the
`grind ac` solver to the new framework. Future PRs will document the API
and port the remaining solvers. An additional benefit of the new design
is faster build times.
This commit is contained in:
Leonardo de Moura 2025-09-07 18:11:05 -07:00 committed by GitHub
parent 6a8d7cc17c
commit be1e090833
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 220 additions and 136 deletions

View file

@ -16,8 +16,9 @@ public import Lean.Meta.Tactic.Grind.AC.DenoteExpr
public import Lean.Meta.Tactic.Grind.AC.ToExpr
public import Lean.Meta.Tactic.Grind.AC.VarRename
public import Lean.Meta.Tactic.Grind.AC.PP
public import Lean.Meta.Tactic.Grind.AC.Inv
public section
namespace Lean
namespace Lean.Meta.Grind.AC
builtin_initialize registerTraceClass `grind.ac
builtin_initialize registerTraceClass `grind.ac.assert
builtin_initialize registerTraceClass `grind.ac.internalize
@ -29,4 +30,13 @@ builtin_initialize registerTraceClass `grind.debug.ac.check
builtin_initialize registerTraceClass `grind.debug.ac.queue
builtin_initialize registerTraceClass `grind.debug.ac.superpose
builtin_initialize registerTraceClass `grind.debug.ac.eq
end Lean
builtin_initialize
acExt.setMethods
(internalize := AC.internalize)
(newEq := AC.processNewEq)
(newDiseq := AC.processNewDiseq)
(check := AC.check)
(checkInv := AC.checkInvariants)
end Lean.Meta.Grind.AC

View file

@ -360,8 +360,7 @@ private def EqCnstr.assert (c : EqCnstr) : ACM Unit := do
else
c.addToQueue
@[export lean_process_ac_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do
def processNewEq(a b : Expr) : GoalM Unit := withExprs a b do
let ea ← asACExpr a
let lhs ← norm ea
let eb ← asACExpr b
@ -369,8 +368,7 @@ def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do
let c ← mkEqCnstr lhs rhs (.core a b ea eb)
c.assert
@[export lean_process_ac_diseq]
def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do
def processNewDiseq (a b : Expr) : GoalM Unit := withExprs a b do
let ea ← asACExpr a
let lhs ← norm ea
let eb ← asACExpr b

View file

@ -22,8 +22,7 @@ partial def reify (e : Expr) : ACM Grind.AC.Expr := do
else
return .var (← mkVar e)
@[export lean_grind_ac_internalize]
def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
unless (← getConfig).ac do return ()
unless e.isApp && e.appFn!.isApp do return ()
let op := e.appFn!.appFn!
@ -39,6 +38,6 @@ def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do
}
trace[grind.ac.internalize] "[{id}] {← e'.denoteExpr}"
addTermOpId e
markAsACTerm e
acExt.markTerm e
end Lean.Meta.Grind.AC

View file

@ -53,7 +53,7 @@ private def ppStruct? : M (Option MessageData) := do
def pp? (goal : Goal) : MetaM (Option MessageData) := do
let mut msgs := #[]
for struct in goal.ac.structs do
for struct in (← acExt.getStateCore goal).structs do
let some msg ← ppStruct? |>.run' struct | pure ()
msgs := msgs.push msg
if msgs.isEmpty then

View file

@ -8,9 +8,7 @@ prelude
public import Init.Core
public import Init.Grind.AC
public import Std.Data.HashMap
public import Lean.Expr
public import Lean.Data.PersistentArray
public import Lean.Meta.Tactic.Grind.ExprPtr
public import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.AC.Seq
public section
namespace Lean.Meta.Grind.AC
@ -98,7 +96,7 @@ structure Struct where
/-- Mapping from Lean expressions to their representations as `AC.Expr` -/
denote : PHashMap ExprPtr AC.Expr := {}
/-- `denoteEntries` is `denote` as a `PArray` for deterministic traversal. -/
denoteEntries : PArray (Expr × AC.Expr) := {}
denoteEntries : PArray (Expr × AC.Expr) := {}
/-- Equations to process. -/
queue : Queue := {}
/-- Processed equations. -/
@ -130,4 +128,6 @@ structure State where
steps := 0
deriving Inhabited
builtin_initialize acExt : SolverExtension State ← registerSolverExtension (return {})
end Lean.Meta.Grind.AC

View file

@ -5,7 +5,7 @@ Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.Types
public import Lean.Meta.Tactic.Grind.AC.Types
public import Lean.Meta.Tactic.Grind.ProveEq
public import Lean.Meta.Tactic.Grind.SynthInstance
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
@ -15,10 +15,10 @@ namespace Lean.Meta.Grind.AC
open Lean.Grind
def get' : GoalM State := do
return (← get).ac
acExt.getState
@[inline] def modify' (f : State → State) : GoalM Unit := do
modify fun s => { s with ac := f s.ac }
acExt.modifyState f
def checkMaxSteps : GoalM Bool := do
return (← get').steps >= (← getConfig).acSteps
@ -115,7 +115,7 @@ def mkVar (e : Expr) : ACM AC.Var := do
varMap := s.varMap.insert { expr := e } var
}
addTermOpId e
markAsACTerm e
acExt.markTerm e
return var
def getOpId? (op : Expr) : GoalM (Option Nat) := do

View file

@ -7,12 +7,11 @@ module
prelude
public import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Arith.Main
import Lean.Meta.Tactic.Grind.AC.Eq
namespace Lean.Meta.Grind
/--
Checks whether satellite solvers can make progress (e.g., detect unsatisfiability, propagate equations, etc)
-/
public def check : GoalM Bool := do
Arith.check <||> AC.check
Arith.check
namespace Lean.Meta.Grind

View file

@ -215,31 +215,6 @@ def propagateLinarith : PendingTheoryPropagation → GoalM Unit
| .diseqs ps => propagateLinarithDiseqs ps
| _ => return ()
/--
Helper function for combining `ENode.ac?` fields and detecting what needs to be
propagated to the ac module.
-/
private def checkACEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
match lhsRoot.ac? with
| some lhs =>
if let some rhs := rhsRoot.ac? then
return .eq lhs rhs
else
-- We have to retrieve the node because other fields have been updated
let rhsRoot ← getENode rhsRoot.self
setENode rhsRoot.self { rhsRoot with ac? := lhs }
return .diseqs (← getParents rhsRoot.self)
| none =>
if rhsRoot.ac?.isSome then
return .diseqs (← getParents lhsRoot.self)
else
return .none
def propagateAC : PendingTheoryPropagation → GoalM Unit
| .eq lhs rhs => AC.processNewEq lhs rhs
| .diseqs ps => propagateACDiseqs ps
| _ => return ()
/--
Tries to apply beta-reduction using the parent applications of the functions in `fns` with
the lambda expressions in `lams`.
@ -374,7 +349,7 @@ where
let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot
let ringTodo ← checkCommRingEq rhsRoot lhsRoot
let linarithTodo ← checkLinarithEq rhsRoot lhsRoot
let ACTodo ← checkACEq rhsRoot lhsRoot
let todo ← Solvers.mergeTerms rhsRoot lhsRoot
resetParentsOf lhsRoot.self
copyParentsTo parents rhsNode.root
unless (← isInconsistent) do
@ -389,7 +364,7 @@ where
propagateCutsat cutsatTodo
propagateCommRing ringTodo
propagateLinarith linarithTodo
propagateAC ACTodo
todo.propagate
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
let isFalseRoot ← isFalseExpr rootNew
traverseEqc lhs fun n => do

View file

@ -21,8 +21,6 @@ import Lean.Meta.Tactic.Grind.MarkNestedSubsingletons
public section
namespace Lean.Meta.Grind
@[extern "lean_grind_ac_internalize"] -- forward definition
opaque AC.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit
@[extern "lean_grind_arith_internalize"] -- forward definition
opaque Arith.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit
@ -366,7 +364,6 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do
private def internalizeTheories (e : Expr) (parent? : Option Expr := none) : GoalM Unit := do
Arith.internalize e parent?
AC.internalize e parent?
@[export lean_grind_internalize]
private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do
@ -381,6 +378,7 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt
Otherwise, it will not be able to propagate that `a + 1 = 1` when `a = 0`
-/
internalizeTheories e parent?
Solvers.internalize e parent?
else
go
propagateEtaStruct e generation
@ -467,6 +465,7 @@ where
registerParent e arg
addCongrTable e
internalizeTheories e parent?
Solvers.internalize e parent?
propagateUp e
propagateBetaForNewApp e

View file

@ -9,7 +9,6 @@ public import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Proof
import Lean.Meta.Tactic.Grind.MatchCond
import Lean.Meta.Tactic.Grind.Arith.Inv
import Lean.Meta.Tactic.Grind.AC.Inv
namespace Lean.Meta.Grind
/-!
Debugging support code for checking basic invariants.
@ -126,7 +125,7 @@ public def checkInvariants (expensive := false) : GoalM Unit := do
if expensive then
checkPtrEqImpliesStructEq
Arith.checkInvariants
AC.checkInvariants
Solvers.checkInvariants
if expensive && grind.debug.proofs.get (← getOptions) then
checkProofs

View file

@ -19,7 +19,7 @@ private partial def solve (generation : Nat) : SearchM Bool := withIncRecDepth d
return false -- `splitNext` should have been configured to not create choice points
if (← getGoal).inconsistent then
return true
if (← intros' generation <||> assertAll <||> check <||> splitNext <||> ematch) then
if (← intros' generation <||> assertAll <||> check <||> Solvers.check <||> splitNext <||> ematch) then
solve generation
else
return false

View file

@ -109,7 +109,8 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
let thmMap := params.ematch
let casesTypes := params.casesTypes
let clean ← mkCleanState mvarId params
GoalM.run' { mvarId, ematch.thmMap := thmMap, split.casesTypes := casesTypes, clean } do
let sstates ← Solvers.mkInitialStates
GoalM.run' { mvarId, ematch.thmMap := thmMap, split.casesTypes := casesTypes, clean, sstates } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0)

View file

@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Grind.Util
public import Init.Grind.PP
@ -14,7 +13,6 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.PP
public import Lean.Meta.Tactic.Grind.Arith.Linear.PP
public import Lean.Meta.Tactic.Grind.AC.PP
import Lean.PrettyPrinter
public section
namespace Lean.Meta.Grind

View file

@ -186,7 +186,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
propagateCutsatDiseq lhs rhs
propagateCommRingDiseq lhs rhs
propagateLinarithDiseq lhs rhs
propagateACDiseq lhs rhs
Solvers.propagateDiseqs lhs rhs
let thms ← getExtTheorems α
if !thms.isEmpty then
/-

View file

@ -51,8 +51,8 @@ where
intros gen
else
break
if (← assertAll <||> check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc
<||> Arith.Linear.mbtc <||> tryFallback) then
if (← assertAll <||> check <||> Solvers.check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc
<||> Arith.Linear.mbtc <||> Solvers.mbtc <||> tryFallback) then
continue
return some (← getGoal) -- failed
return none -- solved

View file

@ -16,7 +16,6 @@ public import Lean.Meta.Tactic.Grind.AlphaShareCommon
public import Lean.Meta.Tactic.Grind.Attr
public import Lean.Meta.Tactic.Grind.ExtAttr
public import Lean.Meta.Tactic.Grind.Arith.Types
public import Lean.Meta.Tactic.Grind.AC.Types
public import Lean.Meta.Tactic.Grind.EMatchTheorem
meta import Lean.Parser.Do
import Lean.Meta.Match.MatchEqsExt
@ -379,6 +378,23 @@ private def expandReportIssueMacro (s : Syntax) : MacroM (TSyntax `doElem) := do
macro "reportIssue!" s:(interpolatedStr(term) <|> term) : doElem => do
expandReportIssueMacro s.raw
/--
Each E-node may have "solver terms" attached to them.
Each term is an element of the equivalence class that the
solver cares about. Each solver is responsible for marking the terms they care about.
The `grind` core propagates equalities and disequalities to the theory solvers
using these "marked" terms. The root of the equivalence class
contains a list of representatives sorted by solver id. Note that many E-nodes
do not have any solver terms attached to them.
"Solver terms" are referenced as "theory variables" in the SMT literature.
The SMT solver Z3 uses a similar representation.
-/
inductive SolverTerms where
| nil
| next (solverId : Nat) (e : Expr) (rest : SolverTerms)
deriving Inhabited, Repr
/--
Stores information for a node in the E-graph.
Each internalized expression `e` has an `ENode` associated with it.
@ -445,13 +461,8 @@ structure ENode where
to the linarith module. Its implementation is similar to the `offset?` field.
-/
linarith? : Option Expr := none
/--
The `ac?` field is used to propagate equalities from the `grind` congruence closure module
to the ac module. Its implementation is similar to the `offset?` field.
-/
ac? : Option Expr := none
-- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, linarith, and ac.
-- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3.
/-- Solver terms attached to this E-node. -/
sTerms : SolverTerms := .nil
deriving Inhabited, Repr
def ENode.isRoot (n : ENode) :=
@ -773,8 +784,6 @@ structure Goal where
split : Split.State := {}
/-- State of arithmetic procedures. -/
arith : Arith.State := {}
/-- State of the ac solver. -/
ac : AC.State := {}
/-- State of the clean name generator. -/
clean : Clean.State := {}
/-- Solver states. -/
@ -1171,14 +1180,10 @@ and invokes process `Arith.Cutsat.processNewDiseq e₁ e₂`
def propagateCutsatDiseq (lhs rhs : Expr) : GoalM Unit := do
let some lhs ← get? lhs | return ()
let some rhs ← get? rhs | return ()
-- Recall that core can take care of disequalities of the form `1≠2`.
unless isNum lhs && isNum rhs do
Arith.Cutsat.processNewDiseq lhs rhs
Arith.Cutsat.processNewDiseq lhs rhs
where
get? (a : Expr) : GoalM (Option Expr) := do
let root ← getRootENode a
if isNum root.self then
return some root.self
return root.cutsat?
/--
@ -1295,53 +1300,6 @@ def markAsLinarithTerm (e : Expr) : GoalM Unit := do
setENode root.self { root with linarith? := some e }
propagateLinarithDiseqs (← getParents root.self)
/--
Notifies the ac module that `a = b` where
`a` and `b` are terms that have been internalized by this module.
-/
@[extern "lean_process_ac_eq"] -- forward definition
opaque AC.processNewEq (a b : Expr) : GoalM Unit
/--
Notifies the ac module that `a ≠ b` where
`a` and `b` are terms that have been internalized by this module.
-/
@[extern "lean_process_ac_diseq"] -- forward definition
opaque AC.processNewDiseq (a b : Expr) : GoalM Unit
/--
Given `lhs` and `rhs` that are known to be disequal, checks whether
`lhs` and `rhs` have ac terms `e₁` and `e₂` attached to them,
and invokes process `AC.processNewDiseq e₁ e₂`
-/
def propagateACDiseq (lhs rhs : Expr) : GoalM Unit := do
let some lhs ← get? lhs | return ()
let some rhs ← get? rhs | return ()
AC.processNewDiseq lhs rhs
where
get? (a : Expr) : GoalM (Option Expr) := do
return (← getRootENode a).ac?
/--
Traverses disequalities in `parents`, and propagate the ones relevant to the
ac module.
-/
def propagateACDiseqs (parents : ParentSet) : GoalM Unit := do
forEachDiseq parents propagateACDiseq
/--
Marks `e` as a term of interest to the ac module.
If the root of `e`s equivalence class has already a term of interest,
a new equality is propagated to the ac module.
-/
def markAsACTerm (e : Expr) : GoalM Unit := do
let root ← getRootENode e
if let some e' := root.ac? then
AC.processNewEq e e'
else
setENode root.self { root with ac? := some e }
propagateACDiseqs (← getParents root.self)
/-- Returns `true` is `e` is the root of its congruence class. -/
def isCongrRoot (e : Expr) : GoalM Bool := do
return (← getENode e).isCongrRoot
@ -1659,29 +1617,52 @@ structure SolverExtension (σ : Type) where private mk ::
checkInv : GoalM Unit
deriving Inhabited
private builtin_initialize solverExtensionsRef : IO.Ref (Array (SolverExtension EnvExtensionState)) ← IO.mkRef #[]
private builtin_initialize solverExtensionsRef : IO.Ref (Array (SolverExtension SolverExtensionState)) ← IO.mkRef #[]
/--
Registers a new solver extension for `grind`.
Solver extensions can only be registered during initialization.
Reason: We do not use any synchronization primitive to access `solverExtensionsRef`.
-/
def registerSolverExtension {σ : Type}
(mkInitial : IO σ)
(internalize : Expr → (parent? : Option Expr) → GoalM Unit)
(newEq : Expr → Expr → GoalM Unit)
(newDiseq : Expr → Expr → GoalM Unit := fun _ _ => return ())
(mbtc : GoalM Bool := return false)
(check : GoalM Bool := return false)
(checkInv : GoalM Unit := return ()) : IO (SolverExtension σ) := do
def registerSolverExtension {σ : Type} (mkInitial : IO σ) : IO (SolverExtension σ) := do
unless (← initializing) do
throw (IO.userError "failed to register `grind` solver, extensions can only be registered during initialization")
let exts ← solverExtensionsRef.get
let id := exts.size
let ext : SolverExtension σ := { id, mkInitial, internalize, newEq, newDiseq, check, checkInv, mbtc }
let ext : SolverExtension σ := {
id, mkInitial
internalize := fun _ _ => return ()
newEq := fun _ _ => return ()
newDiseq := fun _ _ => return ()
check := fun _ _ => return false
checkInv := fun _ _ => return ()
mbtc := fun _ _ => return false
}
solverExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext)
return ext
def mkInitialExtStates : IO (Array EnvExtensionState) := do
/--
Sets methods/handlers for solver extension `ext`.
Solver extension methods can only be registered during initialization.
Reason: We do not use any synchronization primitive to access `solverExtensionsRef`.
-/
def SolverExtension.setMethods (ext : SolverExtension σ)
(internalize : Expr → (parent? : Option Expr) → GoalM Unit := fun _ _ => return ())
(newEq : Expr → Expr → GoalM Unit := fun _ _ => return ())
(newDiseq : Expr → Expr → GoalM Unit := fun _ _ => return ())
(mbtc : GoalM Bool := return false)
(check : GoalM Bool := return false)
(checkInv : GoalM Unit := return ()) : IO Unit := do
unless (← initializing) do
throw (IO.userError "failed to register `grind` solver, extensions can only be registered during initialization")
unless ext.id < (← solverExtensionsRef.get).size do
throw (IO.userError "failed to register `grind` solver methods, invalid solver id")
solverExtensionsRef.modify fun exts => exts.modify ext.id fun s => { s with
internalize, newEq, newDiseq, mbtc, check, checkInv
}
/-- Returns initial state for registered solvers. -/
def Solvers.mkInitialStates : IO (Array SolverExtensionState) := do
let exts ← solverExtensionsRef.get
exts.mapM fun ext => ext.mkInitial
@ -1696,11 +1677,14 @@ private unsafe def SolverExtension.modifyStateImpl (ext : SolverExtension σ) (f
@[implemented_by SolverExtension.modifyStateImpl]
opaque SolverExtension.modifyState (ext : SolverExtension σ) (f : σσ) : GoalM Unit
private unsafe def SolverExtension.getStateImpl (ext : SolverExtension σ) : GoalM σ := do
return unsafeCast (← get).sstates[ext.id]!
private unsafe def SolverExtension.getStateCoreImpl (ext : SolverExtension σ) (goal : Goal) : IO σ :=
return unsafeCast goal.sstates[ext.id]!
@[implemented_by SolverExtension.getStateImpl]
opaque SolverExtension.getState (ext : SolverExtension σ) : GoalM σ
@[implemented_by SolverExtension.getStateCoreImpl]
opaque SolverExtension.getStateCore (ext : SolverExtension σ) (goal : Goal) : IO σ
def SolverExtension.getState (ext : SolverExtension σ) : GoalM σ := do
ext.getStateCore (← get)
/-- Internalizes given expression in all registered solvers. -/
def Solvers.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
@ -1726,4 +1710,125 @@ def Solvers.mbtc : GoalM Bool := do
result := true
return result
/--
Given a new disequality `lhs ≠ rhs`, propagates it to relevant theories.
-/
def Solvers.propagateDiseqs (lhs rhs : Expr) : GoalM Unit := do
go (← getRootENode lhs).sTerms (← getRootENode rhs).sTerms
where
go (lhsTerms rhsTerms : SolverTerms) : GoalM Unit := do
match lhsTerms, rhsTerms with
| .nil, _ => return ()
| _, .nil => return ()
| .next id₁ lhs lhsTerms, .next id₂ rhs rhsTerms =>
if id₁ == id₂ then
(← solverExtensionsRef.get)[id₁]!.newDiseq lhs rhs
go lhsTerms rhsTerms
else if id₁ < id₂ then
go lhsTerms (.next id₂ rhs rhsTerms)
else
go (.next id₁ lhs lhsTerms) rhsTerms
private def propagateDiseqOf (id : Nat) (lhs rhs : Expr) : GoalM Unit := do
visitLhs (← getRootENode lhs).sTerms
where
visitLhs (sTerms : SolverTerms) : GoalM Unit := do
match sTerms with
| .nil => return ()
| .next id' e sTerms =>
if id == id' then
visitRhs e (← getRootENode rhs).sTerms
else if id < id' then
return ()
else
visitLhs sTerms
visitRhs (lhsTerm : Expr) (sTerms : SolverTerms) : GoalM Unit := do
match sTerms with
| .nil => return ()
| .next id' e sTerms =>
if id == id' then
let rhsTerm := e
(← solverExtensionsRef.get)[id]!.newDiseq lhsTerm rhsTerm
else if id < id' then
return ()
else
visitRhs lhsTerm sTerms
def isSameSolverTerms (a b : SolverTerms) : Bool :=
unsafe ptrEq a b
def SolverExtension.markTerm (ext : SolverExtension σ) (e : Expr) : GoalM Unit := do
let root ← getRootENode e
let id := ext.id
let rec go (sTerms : SolverTerms) : GoalM SolverTerms := do
match sTerms with
| .nil => return .next id e .nil
| .next id' e' sTerms' =>
if id == id' then
(← solverExtensionsRef.get)[id]!.newEq e e'
return sTerms
else if id < id' then
return .next id e sTerms
else
let sTermsNew ← go sTerms'
if isSameSolverTerms sTermsNew sTerms' then
return sTerms
else
return .next id' e' sTermsNew
let sTermsNew ← go root.sTerms
unless isSameSolverTerms sTermsNew root.sTerms do
setENode root.self { root with sTerms := sTermsNew }
forEachDiseq (← getParents root.self) (propagateDiseqOf id)
private inductive PendingSolverPropagationsData where
| nil
| eq (solverId : Nat) (lhs rhs : Expr) (rest : PendingSolverPropagationsData)
| diseqs (solverId : Nat) (ps : ParentSet) (rest : PendingSolverPropagationsData)
structure PendingSolverPropagations where private mk ::
private data : PendingSolverPropagationsData
def Solvers.mergeTerms (rhsRoot lhsRoot : ENode) : GoalM PendingSolverPropagations := do
let (sTerms, data) ← go rhsRoot.sTerms lhsRoot.sTerms
unless sTerms matches .nil do
-- We have to retrieve the node because other fields have been updated
let rhsRoot ← getENode rhsRoot.self
setENode rhsRoot.self { rhsRoot with sTerms }
return { data }
where
toPendingDiseqs (sTerms : SolverTerms) (ps : ParentSet) : PendingSolverPropagationsData :=
match sTerms with
| .nil => .nil
| .next id _ sTerms => .diseqs id ps (toPendingDiseqs sTerms ps)
go (rhsTerms : SolverTerms) (lhsTerms : SolverTerms) : GoalM (SolverTerms × PendingSolverPropagationsData) := do
match rhsTerms, lhsTerms with
| .nil, .nil => return (.nil, .nil)
| .nil, .next .. => return (lhsTerms, toPendingDiseqs lhsTerms (← getParents rhsRoot.self))
| .next .., .nil => return (rhsTerms, toPendingDiseqs rhsTerms (← getParents lhsRoot.self))
| .next id₁ rhs rhsTerms, .next id₂ lhs lhsTerms =>
if id₁ == id₂ then
let (s, p) ← go rhsTerms lhsTerms
return (.next id₁ rhs s, .eq id₁ lhs rhs p)
else if id₁ < id₂ then
let (s, p) ← go rhsTerms (.next id₂ lhs lhsTerms)
return (.next id₁ rhs s, .diseqs id₁ (← getParents lhsRoot.self) p)
else
let (s, p) ← go (.next id₁ rhs rhsTerms) lhsTerms
return (.next id₂ lhs s, .diseqs id₂ (← getParents rhsRoot.self) p)
def PendingSolverPropagations.propagate (p : PendingSolverPropagations) : GoalM Unit := do
go p.data
where
go (p : PendingSolverPropagationsData) : GoalM Unit := do
match p with
| .nil => return ()
| .eq solverId lhs rhs rest =>
(← solverExtensionsRef.get)[solverId]!.newEq lhs rhs
go rest
| .diseqs solverId parentSet rest =>
forEachDiseq parentSet (propagateDiseqOf solverId)
go rest
end Lean.Meta.Grind

View file

@ -1,3 +1,4 @@
// update me!
#include "util/options.h"
namespace lean {