diff --git a/src/Lean/Meta/Tactic/Grind/AC.lean b/src/Lean/Meta/Tactic/Grind/AC.lean index b232708d43..13a160eed3 100644 --- a/src/Lean/Meta/Tactic/Grind/AC.lean +++ b/src/Lean/Meta/Tactic/Grind/AC.lean @@ -16,8 +16,9 @@ public import Lean.Meta.Tactic.Grind.AC.DenoteExpr public import Lean.Meta.Tactic.Grind.AC.ToExpr public import Lean.Meta.Tactic.Grind.AC.VarRename public import Lean.Meta.Tactic.Grind.AC.PP +public import Lean.Meta.Tactic.Grind.AC.Inv public section -namespace Lean +namespace Lean.Meta.Grind.AC builtin_initialize registerTraceClass `grind.ac builtin_initialize registerTraceClass `grind.ac.assert builtin_initialize registerTraceClass `grind.ac.internalize @@ -29,4 +30,13 @@ builtin_initialize registerTraceClass `grind.debug.ac.check builtin_initialize registerTraceClass `grind.debug.ac.queue builtin_initialize registerTraceClass `grind.debug.ac.superpose builtin_initialize registerTraceClass `grind.debug.ac.eq -end Lean + +builtin_initialize + acExt.setMethods + (internalize := AC.internalize) + (newEq := AC.processNewEq) + (newDiseq := AC.processNewDiseq) + (check := AC.check) + (checkInv := AC.checkInvariants) + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Eq.lean b/src/Lean/Meta/Tactic/Grind/AC/Eq.lean index 5febc23087..4eb6058c1f 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Eq.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Eq.lean @@ -360,8 +360,7 @@ private def EqCnstr.assert (c : EqCnstr) : ACM Unit := do else c.addToQueue -@[export lean_process_ac_eq] -def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do +def processNewEq(a b : Expr) : GoalM Unit := withExprs a b do let ea ← asACExpr a let lhs ← norm ea let eb ← asACExpr b @@ -369,8 +368,7 @@ def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do let c ← mkEqCnstr lhs rhs (.core a b ea eb) c.assert -@[export lean_process_ac_diseq] -def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do +def processNewDiseq (a b : Expr) : GoalM Unit := withExprs a b do let ea ← asACExpr a let lhs ← norm ea let eb ← asACExpr b diff --git a/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean index 2b69690a76..d6fca3b7f8 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean @@ -22,8 +22,7 @@ partial def reify (e : Expr) : ACM Grind.AC.Expr := do else return .var (← mkVar e) -@[export lean_grind_ac_internalize] -def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do +def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do unless (← getConfig).ac do return () unless e.isApp && e.appFn!.isApp do return () let op := e.appFn!.appFn! @@ -39,6 +38,6 @@ def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do } trace[grind.ac.internalize] "[{id}] {← e'.denoteExpr}" addTermOpId e - markAsACTerm e + acExt.markTerm e end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/PP.lean b/src/Lean/Meta/Tactic/Grind/AC/PP.lean index 2ff59c8cb6..41ad7b0b1e 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/PP.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/PP.lean @@ -53,7 +53,7 @@ private def ppStruct? : M (Option MessageData) := do def pp? (goal : Goal) : MetaM (Option MessageData) := do let mut msgs := #[] - for struct in goal.ac.structs do + for struct in (← acExt.getStateCore goal).structs do let some msg ← ppStruct? |>.run' struct | pure () msgs := msgs.push msg if msgs.isEmpty then diff --git a/src/Lean/Meta/Tactic/Grind/AC/Types.lean b/src/Lean/Meta/Tactic/Grind/AC/Types.lean index 781abb308e..994c21652d 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Types.lean @@ -8,9 +8,7 @@ prelude public import Init.Core public import Init.Grind.AC public import Std.Data.HashMap -public import Lean.Expr -public import Lean.Data.PersistentArray -public import Lean.Meta.Tactic.Grind.ExprPtr +public import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.AC.Seq public section namespace Lean.Meta.Grind.AC @@ -98,7 +96,7 @@ structure Struct where /-- Mapping from Lean expressions to their representations as `AC.Expr` -/ denote : PHashMap ExprPtr AC.Expr := {} /-- `denoteEntries` is `denote` as a `PArray` for deterministic traversal. -/ - denoteEntries : PArray (Expr × AC.Expr) := {} + denoteEntries : PArray (Expr × AC.Expr) := {} /-- Equations to process. -/ queue : Queue := {} /-- Processed equations. -/ @@ -130,4 +128,6 @@ structure State where steps := 0 deriving Inhabited +builtin_initialize acExt : SolverExtension State ← registerSolverExtension (return {}) + end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Util.lean b/src/Lean/Meta/Tactic/Grind/AC/Util.lean index 224b62823c..8b26259c4e 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Util.lean @@ -5,7 +5,7 @@ Authors: Leonardo de Moura -/ module prelude -public import Lean.Meta.Tactic.Grind.Types +public import Lean.Meta.Tactic.Grind.AC.Types public import Lean.Meta.Tactic.Grind.ProveEq public import Lean.Meta.Tactic.Grind.SynthInstance public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId @@ -15,10 +15,10 @@ namespace Lean.Meta.Grind.AC open Lean.Grind def get' : GoalM State := do - return (← get).ac + acExt.getState @[inline] def modify' (f : State → State) : GoalM Unit := do - modify fun s => { s with ac := f s.ac } + acExt.modifyState f def checkMaxSteps : GoalM Bool := do return (← get').steps >= (← getConfig).acSteps @@ -115,7 +115,7 @@ def mkVar (e : Expr) : ACM AC.Var := do varMap := s.varMap.insert { expr := e } var } addTermOpId e - markAsACTerm e + acExt.markTerm e return var def getOpId? (op : Expr) : GoalM (Option Nat) := do diff --git a/src/Lean/Meta/Tactic/Grind/Check.lean b/src/Lean/Meta/Tactic/Grind/Check.lean index 7fa00485b5..618744c669 100644 --- a/src/Lean/Meta/Tactic/Grind/Check.lean +++ b/src/Lean/Meta/Tactic/Grind/Check.lean @@ -7,12 +7,11 @@ module prelude public import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Arith.Main -import Lean.Meta.Tactic.Grind.AC.Eq namespace Lean.Meta.Grind /-- Checks whether satellite solvers can make progress (e.g., detect unsatisfiability, propagate equations, etc) -/ public def check : GoalM Bool := do - Arith.check <||> AC.check + Arith.check namespace Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index 284aaefb08..170a6184a3 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -215,31 +215,6 @@ def propagateLinarith : PendingTheoryPropagation → GoalM Unit | .diseqs ps => propagateLinarithDiseqs ps | _ => return () -/-- -Helper function for combining `ENode.ac?` fields and detecting what needs to be -propagated to the ac module. --/ -private def checkACEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do - match lhsRoot.ac? with - | some lhs => - if let some rhs := rhsRoot.ac? then - return .eq lhs rhs - else - -- We have to retrieve the node because other fields have been updated - let rhsRoot ← getENode rhsRoot.self - setENode rhsRoot.self { rhsRoot with ac? := lhs } - return .diseqs (← getParents rhsRoot.self) - | none => - if rhsRoot.ac?.isSome then - return .diseqs (← getParents lhsRoot.self) - else - return .none - -def propagateAC : PendingTheoryPropagation → GoalM Unit - | .eq lhs rhs => AC.processNewEq lhs rhs - | .diseqs ps => propagateACDiseqs ps - | _ => return () - /-- Tries to apply beta-reduction using the parent applications of the functions in `fns` with the lambda expressions in `lams`. @@ -374,7 +349,7 @@ where let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot let ringTodo ← checkCommRingEq rhsRoot lhsRoot let linarithTodo ← checkLinarithEq rhsRoot lhsRoot - let ACTodo ← checkACEq rhsRoot lhsRoot + let todo ← Solvers.mergeTerms rhsRoot lhsRoot resetParentsOf lhsRoot.self copyParentsTo parents rhsNode.root unless (← isInconsistent) do @@ -389,7 +364,7 @@ where propagateCutsat cutsatTodo propagateCommRing ringTodo propagateLinarith linarithTodo - propagateAC ACTodo + todo.propagate updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do let isFalseRoot ← isFalseExpr rootNew traverseEqc lhs fun n => do diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index 6b32061f97..d76c53ede2 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -21,8 +21,6 @@ import Lean.Meta.Tactic.Grind.MarkNestedSubsingletons public section namespace Lean.Meta.Grind -@[extern "lean_grind_ac_internalize"] -- forward definition -opaque AC.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit @[extern "lean_grind_arith_internalize"] -- forward definition opaque Arith.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit @@ -366,7 +364,6 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do private def internalizeTheories (e : Expr) (parent? : Option Expr := none) : GoalM Unit := do Arith.internalize e parent? - AC.internalize e parent? @[export lean_grind_internalize] private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do @@ -381,6 +378,7 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt Otherwise, it will not be able to propagate that `a + 1 = 1` when `a = 0` -/ internalizeTheories e parent? + Solvers.internalize e parent? else go propagateEtaStruct e generation @@ -467,6 +465,7 @@ where registerParent e arg addCongrTable e internalizeTheories e parent? + Solvers.internalize e parent? propagateUp e propagateBetaForNewApp e diff --git a/src/Lean/Meta/Tactic/Grind/Inv.lean b/src/Lean/Meta/Tactic/Grind/Inv.lean index e599865af3..15cfa8f071 100644 --- a/src/Lean/Meta/Tactic/Grind/Inv.lean +++ b/src/Lean/Meta/Tactic/Grind/Inv.lean @@ -9,7 +9,6 @@ public import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Proof import Lean.Meta.Tactic.Grind.MatchCond import Lean.Meta.Tactic.Grind.Arith.Inv -import Lean.Meta.Tactic.Grind.AC.Inv namespace Lean.Meta.Grind /-! Debugging support code for checking basic invariants. @@ -126,7 +125,7 @@ public def checkInvariants (expensive := false) : GoalM Unit := do if expensive then checkPtrEqImpliesStructEq Arith.checkInvariants - AC.checkInvariants + Solvers.checkInvariants if expensive && grind.debug.proofs.get (← getOptions) then checkProofs diff --git a/src/Lean/Meta/Tactic/Grind/Lookahead.lean b/src/Lean/Meta/Tactic/Grind/Lookahead.lean index e9b757b57f..6d659e9a77 100644 --- a/src/Lean/Meta/Tactic/Grind/Lookahead.lean +++ b/src/Lean/Meta/Tactic/Grind/Lookahead.lean @@ -19,7 +19,7 @@ private partial def solve (generation : Nat) : SearchM Bool := withIncRecDepth d return false -- `splitNext` should have been configured to not create choice points if (← getGoal).inconsistent then return true - if (← intros' generation <||> assertAll <||> check <||> splitNext <||> ematch) then + if (← intros' generation <||> assertAll <||> check <||> Solvers.check <||> splitNext <||> ematch) then solve generation else return false diff --git a/src/Lean/Meta/Tactic/Grind/Main.lean b/src/Lean/Meta/Tactic/Grind/Main.lean index 054a6e7c1d..46cf108ddf 100644 --- a/src/Lean/Meta/Tactic/Grind/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Main.lean @@ -109,7 +109,8 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do let thmMap := params.ematch let casesTypes := params.casesTypes let clean ← mkCleanState mvarId params - GoalM.run' { mvarId, ematch.thmMap := thmMap, split.casesTypes := casesTypes, clean } do + let sstates ← Solvers.mkInitialStates + GoalM.run' { mvarId, ematch.thmMap := thmMap, split.casesTypes := casesTypes, clean, sstates } do mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0) mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0) mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0) diff --git a/src/Lean/Meta/Tactic/Grind/PP.lean b/src/Lean/Meta/Tactic/Grind/PP.lean index 6fb4490b07..e3d6b4fc61 100644 --- a/src/Lean/Meta/Tactic/Grind/PP.lean +++ b/src/Lean/Meta/Tactic/Grind/PP.lean @@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Grind.Util public import Init.Grind.PP @@ -14,7 +13,6 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.PP public import Lean.Meta.Tactic.Grind.Arith.Linear.PP public import Lean.Meta.Tactic.Grind.AC.PP import Lean.PrettyPrinter - public section namespace Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index 35d3a0b7fc..56dabcfcae 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -186,7 +186,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do propagateCutsatDiseq lhs rhs propagateCommRingDiseq lhs rhs propagateLinarithDiseq lhs rhs - propagateACDiseq lhs rhs + Solvers.propagateDiseqs lhs rhs let thms ← getExtTheorems α if !thms.isEmpty then /- diff --git a/src/Lean/Meta/Tactic/Grind/Solve.lean b/src/Lean/Meta/Tactic/Grind/Solve.lean index 7427143fec..c59f78a476 100644 --- a/src/Lean/Meta/Tactic/Grind/Solve.lean +++ b/src/Lean/Meta/Tactic/Grind/Solve.lean @@ -51,8 +51,8 @@ where intros gen else break - if (← assertAll <||> check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc - <||> Arith.Linear.mbtc <||> tryFallback) then + if (← assertAll <||> check <||> Solvers.check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc + <||> Arith.Linear.mbtc <||> Solvers.mbtc <||> tryFallback) then continue return some (← getGoal) -- failed return none -- solved diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index a3b28b30f1..0ac9bef6b8 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -16,7 +16,6 @@ public import Lean.Meta.Tactic.Grind.AlphaShareCommon public import Lean.Meta.Tactic.Grind.Attr public import Lean.Meta.Tactic.Grind.ExtAttr public import Lean.Meta.Tactic.Grind.Arith.Types -public import Lean.Meta.Tactic.Grind.AC.Types public import Lean.Meta.Tactic.Grind.EMatchTheorem meta import Lean.Parser.Do import Lean.Meta.Match.MatchEqsExt @@ -379,6 +378,23 @@ private def expandReportIssueMacro (s : Syntax) : MacroM (TSyntax `doElem) := do macro "reportIssue!" s:(interpolatedStr(term) <|> term) : doElem => do expandReportIssueMacro s.raw +/-- +Each E-node may have "solver terms" attached to them. +Each term is an element of the equivalence class that the +solver cares about. Each solver is responsible for marking the terms they care about. +The `grind` core propagates equalities and disequalities to the theory solvers +using these "marked" terms. The root of the equivalence class +contains a list of representatives sorted by solver id. Note that many E-nodes +do not have any solver terms attached to them. + +"Solver terms" are referenced as "theory variables" in the SMT literature. +The SMT solver Z3 uses a similar representation. +-/ +inductive SolverTerms where + | nil + | next (solverId : Nat) (e : Expr) (rest : SolverTerms) + deriving Inhabited, Repr + /-- Stores information for a node in the E-graph. Each internalized expression `e` has an `ENode` associated with it. @@ -445,13 +461,8 @@ structure ENode where to the linarith module. Its implementation is similar to the `offset?` field. -/ linarith? : Option Expr := none - /-- - The `ac?` field is used to propagate equalities from the `grind` congruence closure module - to the ac module. Its implementation is similar to the `offset?` field. - -/ - ac? : Option Expr := none - -- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, linarith, and ac. - -- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3. + /-- Solver terms attached to this E-node. -/ + sTerms : SolverTerms := .nil deriving Inhabited, Repr def ENode.isRoot (n : ENode) := @@ -773,8 +784,6 @@ structure Goal where split : Split.State := {} /-- State of arithmetic procedures. -/ arith : Arith.State := {} - /-- State of the ac solver. -/ - ac : AC.State := {} /-- State of the clean name generator. -/ clean : Clean.State := {} /-- Solver states. -/ @@ -1171,14 +1180,10 @@ and invokes process `Arith.Cutsat.processNewDiseq e₁ e₂` def propagateCutsatDiseq (lhs rhs : Expr) : GoalM Unit := do let some lhs ← get? lhs | return () let some rhs ← get? rhs | return () - -- Recall that core can take care of disequalities of the form `1≠2`. - unless isNum lhs && isNum rhs do - Arith.Cutsat.processNewDiseq lhs rhs + Arith.Cutsat.processNewDiseq lhs rhs where get? (a : Expr) : GoalM (Option Expr) := do let root ← getRootENode a - if isNum root.self then - return some root.self return root.cutsat? /-- @@ -1295,53 +1300,6 @@ def markAsLinarithTerm (e : Expr) : GoalM Unit := do setENode root.self { root with linarith? := some e } propagateLinarithDiseqs (← getParents root.self) -/-- -Notifies the ac module that `a = b` where -`a` and `b` are terms that have been internalized by this module. --/ -@[extern "lean_process_ac_eq"] -- forward definition -opaque AC.processNewEq (a b : Expr) : GoalM Unit - -/-- -Notifies the ac module that `a ≠ b` where -`a` and `b` are terms that have been internalized by this module. --/ -@[extern "lean_process_ac_diseq"] -- forward definition -opaque AC.processNewDiseq (a b : Expr) : GoalM Unit - -/-- -Given `lhs` and `rhs` that are known to be disequal, checks whether -`lhs` and `rhs` have ac terms `e₁` and `e₂` attached to them, -and invokes process `AC.processNewDiseq e₁ e₂` --/ -def propagateACDiseq (lhs rhs : Expr) : GoalM Unit := do - let some lhs ← get? lhs | return () - let some rhs ← get? rhs | return () - AC.processNewDiseq lhs rhs -where - get? (a : Expr) : GoalM (Option Expr) := do - return (← getRootENode a).ac? - -/-- -Traverses disequalities in `parents`, and propagate the ones relevant to the -ac module. --/ -def propagateACDiseqs (parents : ParentSet) : GoalM Unit := do - forEachDiseq parents propagateACDiseq - -/-- -Marks `e` as a term of interest to the ac module. -If the root of `e`s equivalence class has already a term of interest, -a new equality is propagated to the ac module. --/ -def markAsACTerm (e : Expr) : GoalM Unit := do - let root ← getRootENode e - if let some e' := root.ac? then - AC.processNewEq e e' - else - setENode root.self { root with ac? := some e } - propagateACDiseqs (← getParents root.self) - /-- Returns `true` is `e` is the root of its congruence class. -/ def isCongrRoot (e : Expr) : GoalM Bool := do return (← getENode e).isCongrRoot @@ -1659,29 +1617,52 @@ structure SolverExtension (σ : Type) where private mk :: checkInv : GoalM Unit deriving Inhabited -private builtin_initialize solverExtensionsRef : IO.Ref (Array (SolverExtension EnvExtensionState)) ← IO.mkRef #[] +private builtin_initialize solverExtensionsRef : IO.Ref (Array (SolverExtension SolverExtensionState)) ← IO.mkRef #[] /-- +Registers a new solver extension for `grind`. Solver extensions can only be registered during initialization. Reason: We do not use any synchronization primitive to access `solverExtensionsRef`. -/ -def registerSolverExtension {σ : Type} - (mkInitial : IO σ) - (internalize : Expr → (parent? : Option Expr) → GoalM Unit) - (newEq : Expr → Expr → GoalM Unit) - (newDiseq : Expr → Expr → GoalM Unit := fun _ _ => return ()) - (mbtc : GoalM Bool := return false) - (check : GoalM Bool := return false) - (checkInv : GoalM Unit := return ()) : IO (SolverExtension σ) := do +def registerSolverExtension {σ : Type} (mkInitial : IO σ) : IO (SolverExtension σ) := do unless (← initializing) do throw (IO.userError "failed to register `grind` solver, extensions can only be registered during initialization") let exts ← solverExtensionsRef.get let id := exts.size - let ext : SolverExtension σ := { id, mkInitial, internalize, newEq, newDiseq, check, checkInv, mbtc } + let ext : SolverExtension σ := { + id, mkInitial + internalize := fun _ _ => return () + newEq := fun _ _ => return () + newDiseq := fun _ _ => return () + check := fun _ _ => return false + checkInv := fun _ _ => return () + mbtc := fun _ _ => return false + } solverExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext) return ext -def mkInitialExtStates : IO (Array EnvExtensionState) := do +/-- +Sets methods/handlers for solver extension `ext`. +Solver extension methods can only be registered during initialization. +Reason: We do not use any synchronization primitive to access `solverExtensionsRef`. +-/ +def SolverExtension.setMethods (ext : SolverExtension σ) + (internalize : Expr → (parent? : Option Expr) → GoalM Unit := fun _ _ => return ()) + (newEq : Expr → Expr → GoalM Unit := fun _ _ => return ()) + (newDiseq : Expr → Expr → GoalM Unit := fun _ _ => return ()) + (mbtc : GoalM Bool := return false) + (check : GoalM Bool := return false) + (checkInv : GoalM Unit := return ()) : IO Unit := do + unless (← initializing) do + throw (IO.userError "failed to register `grind` solver, extensions can only be registered during initialization") + unless ext.id < (← solverExtensionsRef.get).size do + throw (IO.userError "failed to register `grind` solver methods, invalid solver id") + solverExtensionsRef.modify fun exts => exts.modify ext.id fun s => { s with + internalize, newEq, newDiseq, mbtc, check, checkInv + } + +/-- Returns initial state for registered solvers. -/ +def Solvers.mkInitialStates : IO (Array SolverExtensionState) := do let exts ← solverExtensionsRef.get exts.mapM fun ext => ext.mkInitial @@ -1696,11 +1677,14 @@ private unsafe def SolverExtension.modifyStateImpl (ext : SolverExtension σ) (f @[implemented_by SolverExtension.modifyStateImpl] opaque SolverExtension.modifyState (ext : SolverExtension σ) (f : σ → σ) : GoalM Unit -private unsafe def SolverExtension.getStateImpl (ext : SolverExtension σ) : GoalM σ := do - return unsafeCast (← get).sstates[ext.id]! +private unsafe def SolverExtension.getStateCoreImpl (ext : SolverExtension σ) (goal : Goal) : IO σ := + return unsafeCast goal.sstates[ext.id]! -@[implemented_by SolverExtension.getStateImpl] -opaque SolverExtension.getState (ext : SolverExtension σ) : GoalM σ +@[implemented_by SolverExtension.getStateCoreImpl] +opaque SolverExtension.getStateCore (ext : SolverExtension σ) (goal : Goal) : IO σ + +def SolverExtension.getState (ext : SolverExtension σ) : GoalM σ := do + ext.getStateCore (← get) /-- Internalizes given expression in all registered solvers. -/ def Solvers.internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do @@ -1726,4 +1710,125 @@ def Solvers.mbtc : GoalM Bool := do result := true return result +/-- +Given a new disequality `lhs ≠ rhs`, propagates it to relevant theories. +-/ +def Solvers.propagateDiseqs (lhs rhs : Expr) : GoalM Unit := do + go (← getRootENode lhs).sTerms (← getRootENode rhs).sTerms +where + go (lhsTerms rhsTerms : SolverTerms) : GoalM Unit := do + match lhsTerms, rhsTerms with + | .nil, _ => return () + | _, .nil => return () + | .next id₁ lhs lhsTerms, .next id₂ rhs rhsTerms => + if id₁ == id₂ then + (← solverExtensionsRef.get)[id₁]!.newDiseq lhs rhs + go lhsTerms rhsTerms + else if id₁ < id₂ then + go lhsTerms (.next id₂ rhs rhsTerms) + else + go (.next id₁ lhs lhsTerms) rhsTerms + +private def propagateDiseqOf (id : Nat) (lhs rhs : Expr) : GoalM Unit := do + visitLhs (← getRootENode lhs).sTerms +where + visitLhs (sTerms : SolverTerms) : GoalM Unit := do + match sTerms with + | .nil => return () + | .next id' e sTerms => + if id == id' then + visitRhs e (← getRootENode rhs).sTerms + else if id < id' then + return () + else + visitLhs sTerms + + visitRhs (lhsTerm : Expr) (sTerms : SolverTerms) : GoalM Unit := do + match sTerms with + | .nil => return () + | .next id' e sTerms => + if id == id' then + let rhsTerm := e + (← solverExtensionsRef.get)[id]!.newDiseq lhsTerm rhsTerm + else if id < id' then + return () + else + visitRhs lhsTerm sTerms + +def isSameSolverTerms (a b : SolverTerms) : Bool := + unsafe ptrEq a b + +def SolverExtension.markTerm (ext : SolverExtension σ) (e : Expr) : GoalM Unit := do + let root ← getRootENode e + let id := ext.id + let rec go (sTerms : SolverTerms) : GoalM SolverTerms := do + match sTerms with + | .nil => return .next id e .nil + | .next id' e' sTerms' => + if id == id' then + (← solverExtensionsRef.get)[id]!.newEq e e' + return sTerms + else if id < id' then + return .next id e sTerms + else + let sTermsNew ← go sTerms' + if isSameSolverTerms sTermsNew sTerms' then + return sTerms + else + return .next id' e' sTermsNew + let sTermsNew ← go root.sTerms + unless isSameSolverTerms sTermsNew root.sTerms do + setENode root.self { root with sTerms := sTermsNew } + forEachDiseq (← getParents root.self) (propagateDiseqOf id) + +private inductive PendingSolverPropagationsData where + | nil + | eq (solverId : Nat) (lhs rhs : Expr) (rest : PendingSolverPropagationsData) + | diseqs (solverId : Nat) (ps : ParentSet) (rest : PendingSolverPropagationsData) + +structure PendingSolverPropagations where private mk :: + private data : PendingSolverPropagationsData + +def Solvers.mergeTerms (rhsRoot lhsRoot : ENode) : GoalM PendingSolverPropagations := do + let (sTerms, data) ← go rhsRoot.sTerms lhsRoot.sTerms + unless sTerms matches .nil do + -- We have to retrieve the node because other fields have been updated + let rhsRoot ← getENode rhsRoot.self + setENode rhsRoot.self { rhsRoot with sTerms } + return { data } +where + toPendingDiseqs (sTerms : SolverTerms) (ps : ParentSet) : PendingSolverPropagationsData := + match sTerms with + | .nil => .nil + | .next id _ sTerms => .diseqs id ps (toPendingDiseqs sTerms ps) + + go (rhsTerms : SolverTerms) (lhsTerms : SolverTerms) : GoalM (SolverTerms × PendingSolverPropagationsData) := do + match rhsTerms, lhsTerms with + | .nil, .nil => return (.nil, .nil) + | .nil, .next .. => return (lhsTerms, toPendingDiseqs lhsTerms (← getParents rhsRoot.self)) + | .next .., .nil => return (rhsTerms, toPendingDiseqs rhsTerms (← getParents lhsRoot.self)) + | .next id₁ rhs rhsTerms, .next id₂ lhs lhsTerms => + if id₁ == id₂ then + let (s, p) ← go rhsTerms lhsTerms + return (.next id₁ rhs s, .eq id₁ lhs rhs p) + else if id₁ < id₂ then + let (s, p) ← go rhsTerms (.next id₂ lhs lhsTerms) + return (.next id₁ rhs s, .diseqs id₁ (← getParents lhsRoot.self) p) + else + let (s, p) ← go (.next id₁ rhs rhsTerms) lhsTerms + return (.next id₂ lhs s, .diseqs id₂ (← getParents rhsRoot.self) p) + +def PendingSolverPropagations.propagate (p : PendingSolverPropagations) : GoalM Unit := do + go p.data +where + go (p : PendingSolverPropagationsData) : GoalM Unit := do + match p with + | .nil => return () + | .eq solverId lhs rhs rest => + (← solverExtensionsRef.get)[solverId]!.newEq lhs rhs + go rest + | .diseqs solverId parentSet rest => + forEachDiseq parentSet (propagateDiseqOf solverId) + go rest + end Lean.Meta.Grind diff --git a/stage0/src/stdlib_flags.h b/stage0/src/stdlib_flags.h index 79a0e58edd..ad491b0de1 100644 --- a/stage0/src/stdlib_flags.h +++ b/stage0/src/stdlib_flags.h @@ -1,3 +1,4 @@ +// update me! #include "util/options.h" namespace lean {