feat: add addEdge to grind order (#10596)

This PR implements the function for adding new edges to the graph used
by `grind order`. The graph maintains the transitive closure of all
asserted constraints.
This commit is contained in:
Leonardo de Moura 2025-09-27 11:18:41 -07:00 committed by GitHub
parent fbfc7694a0
commit 0504e32bb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 233 additions and 41 deletions

View file

@ -196,6 +196,20 @@ theorem le_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPre
have := Preorder.lt_irrefl a
contradiction
theorem lt_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α]
(a b : α) : a < b → (b < a) = False := by
simp; intro h₁ h₂
have := lt_trans h₁ h₂
have := Preorder.lt_irrefl a
contradiction
theorem lt_eq_false_of_le {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α]
(a b : α) : a ≤ b → (b < a) = False := by
simp; intro h₁ h₂
have := le_lt_trans h₁ h₂
have := Preorder.lt_irrefl a
contradiction
theorem le_eq_false_of_le_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] [Ring α] [OrderedRing α]
(a b : α) (k₁ k₂ : Int) : (k₂ + k₁).blt' 0 → a ≤ b + k₁ → (b ≤ a + k₂) = False := by
intro h₁; simp; intro h₂ h₃

View file

@ -25,23 +25,108 @@ where
let p' ← getProof u p.w
go (← mkTrans p' p v)
/--
Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
this function closes the current goal by constructing a proof of `False`.
-/
def setUnsat (u v : NodeId) (kuv : Weight) (huv : Expr) (kvu : Weight) : OrderM Unit := do
let hvu ← mkProofForPath v u
let u ← getExpr u
let v ← getExpr v
let h ← mkUnsatProof u v kuv huv kvu hvu
closeGoal h
/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
def setDist (u v : NodeId) (k : Weight) : OrderM Unit := do
modifyStruct fun s => { s with
targets := s.targets.modify u fun es => es.insert v k
sources := s.sources.modify v fun es => es.insert u k
}
def setProof (u v : NodeId) (p : ProofInfo) : OrderM Unit := do
modifyStruct fun s => { s with
proofs := s.proofs.modify u fun es => es.insert v p
}
@[inline] def forEachSourceOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
(← getStruct).sources[u]!.forM f
@[inline] def forEachTargetOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
(← getStruct).targets[u]!.forM f
/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
def isShorter (u v : NodeId) (k : Weight) : OrderM Bool := do
if let some k' ← getDist? u v then
return k < k'
else
return true
/-- Adds `p` to the list of things to be propagated. -/
def pushToPropagate (p : ToPropagate) : OrderM Unit :=
modifyStruct fun s => { s with propagate := p :: s.propagate }
/-
def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do
def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do
let kuv ← mkProofForPath u v
let u ← getExpr u
let v ← getExpr v
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k'
let h ← mkPropagateEqTrueProof u v k kuv k'
pushEqTrue e h
private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do
def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do
let kuv ← mkProofForPath u v
let u ← getExpr u
let v ← getExpr v
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k'
let h ← mkPropagateEqFalseProof u v k kuv k'
pushEqFalse e h
/-- Propagates all pending constraints and equalities and resets to "to do" list. -/
private def propagatePending : OrderM Unit := do
let todo := (← getStruct).propagate
modifyStruct fun s => { s with propagate := [] }
for p in todo do
match p with
| .eqTrue e u v k k' => propagateEqTrue e u v k k'
| .eqFalse e u v k k' => propagateEqFalse e u v k k'
| .eq u v =>
let ue ← getExpr u
let ve ← getExpr v
unless (← isEqv ue ve) do
let huv ← mkProofForPath u v
let hvu ← mkProofForPath v u
let h ← mkEqProofOfLeOfLe ue ve huv hvu
pushEq ue ve h
def Cnstr.getWeight? (c : Cnstr α) : Option Weight :=
match c.kind with
| .le => some { k := c.k }
| .lt => some { k := c.k, strict := true }
| .eq => none
/--
Given `e` represented by constraint `c` (from `u` to `v`).
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
let some k' := c.getWeight? | return false
if k ≤ k' then
pushToPropagate <| .eqTrue e u v k k'
return true
else
return false
/--
Given `e` represented by constraint `c` (from `v` to `u`).
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
let some k' := c.getWeight? | return false
if (k + k').isNeg then
pushToPropagate <| .eqFalse e u v k k'
return true
return false
/--
Auxiliary function for implementing theory propagation.
@ -51,8 +136,7 @@ associated with `(u, v)` IF
- `e` is already assigned, or
- `f c e` returns true
-/
@[inline]
private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
@[inline] def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
if let some cs := (← getStruct).cnstrsOf.find? (u, v) then
let cs' ← cs.filterM fun (c, e) => do
if (← isEqTrue e <||> isEqFalse e) then
@ -61,6 +145,62 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM
return !(← f c e)
modifyStruct fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
/-- Equality propagation. -/
def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
if !k.isZero then return ()
let some k' ← getDist? v u | return ()
if !k'.isZero then return ()
let ue ← getExpr u
let ve ← getExpr v
if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
if (← isEqv ue ve) then return ()
pushToPropagate <| .eq u v
/-- Finds constrains and equalities to be propagated. -/
def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do
updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e)
updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e)
checkEq u v k
/--
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
`w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some`
-/
def updateIfShorter (u v : NodeId) (k : Weight) (w : NodeId) : OrderM Unit := do
if (← isShorter u v k) then
setDist u v k
setProof u v (← getProof w v)
checkToPropagate u v k
/--
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
if no negative cycle was created, updates the shortest distance of affected
node pairs.
-/
def addEdge (u : NodeId) (v : NodeId) (k : Weight) (h : Expr) : OrderM Unit := do
if (← isInconsistent) then return ()
if let some k' ← getDist? v u then
if (k + k').isNeg then
setUnsat u v k h k'
return ()
if (← isShorter u v k) then
setDist u v k
setProof u v { w := u, k, proof := h }
checkToPropagate u v k
update
propagatePending
where
update : OrderM Unit := do
forEachTargetOf v fun j k₂ => do
/- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/
updateIfShorter u j (k+k₂) v
forEachSourceOf u fun i k₁ => do
/- Check whether new path: `i -(k₁)-> u -(k)-> v` is shorter -/
updateIfShorter i v (k₁+k) u
forEachTargetOf v fun j k₂ => do
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
updateIfShorter i j (k₁+k+k₂) v
def assertTrue (c : Cnstr NodeId) (p : Expr) : OrderM Unit := do
trace[grind.order.assert] "{p} = True: {← c.pp}"

View file

@ -57,4 +57,7 @@ def getCnstr? (e : Expr) : OrderM (Option (Cnstr NodeId)) :=
def isRing : OrderM Bool :=
return (← getStruct).ringId?.isSome
def isPartialOrder : OrderM Bool :=
return (← getStruct).isPartialInst?.isSome
end Lean.Meta.Grind.Order

View file

@ -16,6 +16,13 @@ def mkLePreorderPrefix (declName : Name) : OrderM Expr := do
let s ← getStruct
return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPreorderInst
/--
Returns `declName α leInst isPartialInst`
-/
def mkLePartialPrefix (declName : Name) : OrderM Expr := do
let s ← getStruct
return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPartialInst?.get!
/--
Returns `declName α leInst ltInst lawfulOrderLtInst`
-/
@ -127,8 +134,13 @@ public def mkPropagateEqTrueProof (u v : Expr) (k : Weight) (huv : Expr) (k' : W
/--
`u < v → (v ≤ u) = False
-/
def mkPropagateEqFalseProofCore (u v : Expr) (huv : Expr) : OrderM Expr := do
let h ← mkLeLtPreorderPrefix ``Grind.Order.le_eq_false_of_lt
def mkPropagateEqFalseProofCore (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do
let declName := match k'.strict, k.strict with
| false, false => unreachable!
| false, true => ``Grind.Order.le_eq_false_of_lt
| true, false => ``Grind.Order.lt_eq_false_of_le
| true, true => ``Grind.Order.lt_eq_false_of_lt
let h ← mkLeLtPreorderPrefix declName
return mkApp3 h u v huv
def mkPropagateEqFalseProofOffset (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do
@ -148,7 +160,7 @@ public def mkPropagateEqFalseProof (u v : Expr) (k : Weight) (huv : Expr) (k' :
if (← isRing) then
mkPropagateEqFalseProofOffset u v k huv k'
else
mkPropagateEqFalseProofCore u v huv
mkPropagateEqFalseProofCore u v k huv k'
def mkUnsatProofCore (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
let h ← mkTransCoreProof u v u k₁.strict k₂.strict h₁ h₂
@ -170,10 +182,14 @@ Returns a proof of `False` using a negative cycle composed of
- `u --(k₁)--> v` with proof `h₁`
- `v --(k₂)--> u` with proof `h₂`
-/
def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
public def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
if (← isRing) then
mkUnsatProofOffset u v k₁ h₁ k₂ h₂
else
mkUnsatProofCore u v k₁ h₁ k₂ h₂
public def mkEqProofOfLeOfLe (u v : Expr) (h₁ : Expr) (h₂ : Expr) : OrderM Expr := do
let h ← mkLePartialPrefix ``Grind.Order.eq_of_le_of_le
return mkApp4 h u v h₁ h₂
end Lean.Meta.Grind.Order

View file

@ -39,33 +39,6 @@ structure Weight where
strict := false
deriving Inhabited
def Weight.compare (a b : Weight) : Ordering :=
if a.k < b.k then
.lt
else if b.k > a.k then
.gt
else if a.strict == b.strict then
.eq
else if !a.strict && b.strict then
.lt
else
.gt
instance : Ord Weight where
compare := Weight.compare
instance : LE Weight where
le a b := compare a b ≠ .gt
instance : LT Weight where
lt a b := compare a b = .lt
instance : DecidableLE Weight :=
fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt))
instance : DecidableLT Weight :=
fun a b => inferInstanceAs (Decidable (compare a b = .lt))
/-- Auxiliary structure used for proof extraction. -/
structure ProofInfo where
w : NodeId
@ -82,8 +55,8 @@ Thus, we store the information to be propagated into a list.
See field `propagate` in `State`.
-/
inductive ToPropagate where
| eqTrue (e : Expr) (u v : NodeId) (k k' : Int)
| eqFalse (e : Expr) (u v : NodeId) (k k' : Int)
| eqTrue (e : Expr) (u v : NodeId) (k k' : Weight)
| eqFalse (e : Expr) (u v : NodeId) (k k' : Weight)
| eq (u v : NodeId)
deriving Inhabited

View file

@ -7,9 +7,10 @@ module
prelude
public import Lean.Meta.Tactic.Grind.Order.OrderM
import Lean.Meta.Tactic.Grind.Arith.Util
public section
namespace Lean.Meta.Grind.Order
public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
let u ← getExpr c.u
let v ← getExpr c.v
let op := match c.kind with
@ -21,4 +22,49 @@ public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
else
return m!"{Arith.quoteIfArithTerm u} {op} {Arith.quoteIfArithTerm v}"
def Weight.compare (a b : Weight) : Ordering :=
if a.k < b.k then
.lt
else if b.k > a.k then
.gt
else if a.strict == b.strict then
.eq
else if a.strict && !b.strict then
/-
**Note**: Recall that we view a constraint of the
form `x < y + k` as `x ≤ y + (k - ε)` where `ε` is
an "infinitesimal" positive value.
Thus, `k - ε < k`
-/
.lt
else
.gt
instance : Ord Weight where
compare := Weight.compare
instance : LE Weight where
le a b := compare a b ≠ .gt
instance : LT Weight where
lt a b := compare a b = .lt
instance : DecidableLE Weight :=
fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt))
instance : DecidableLT Weight :=
fun a b => inferInstanceAs (Decidable (compare a b = .lt))
def Weight.add (a b : Weight) : Weight :=
{ k := a.k + b.k, strict := a.strict || b.strict }
instance : Add Weight where
add := Weight.add
def Weight.isNeg (a : Weight) : Bool :=
a.k < 0 || (a.k == 0 && a.strict)
def Weight.isZero (a : Weight) : Bool :=
a.k == 0 && !a.strict
end Lean.Meta.Grind.Order