feat: add addEdge to grind order (#10596)
This PR implements the function for adding new edges to the graph used by `grind order`. The graph maintains the transitive closure of all asserted constraints.
This commit is contained in:
parent
fbfc7694a0
commit
0504e32bb7
6 changed files with 233 additions and 41 deletions
|
|
@ -196,6 +196,20 @@ theorem le_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPre
|
|||
have := Preorder.lt_irrefl a
|
||||
contradiction
|
||||
|
||||
theorem lt_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α]
|
||||
(a b : α) : a < b → (b < a) = False := by
|
||||
simp; intro h₁ h₂
|
||||
have := lt_trans h₁ h₂
|
||||
have := Preorder.lt_irrefl a
|
||||
contradiction
|
||||
|
||||
theorem lt_eq_false_of_le {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α]
|
||||
(a b : α) : a ≤ b → (b < a) = False := by
|
||||
simp; intro h₁ h₂
|
||||
have := le_lt_trans h₁ h₂
|
||||
have := Preorder.lt_irrefl a
|
||||
contradiction
|
||||
|
||||
theorem le_eq_false_of_le_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] [Ring α] [OrderedRing α]
|
||||
(a b : α) (k₁ k₂ : Int) : (k₂ + k₁).blt' 0 → a ≤ b + k₁ → (b ≤ a + k₂) = False := by
|
||||
intro h₁; simp; intro h₂ h₃
|
||||
|
|
|
|||
|
|
@ -25,23 +25,108 @@ where
|
|||
let p' ← getProof u p.w
|
||||
go (← mkTrans p' p v)
|
||||
|
||||
/--
|
||||
Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
|
||||
it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
|
||||
this function closes the current goal by constructing a proof of `False`.
|
||||
-/
|
||||
def setUnsat (u v : NodeId) (kuv : Weight) (huv : Expr) (kvu : Weight) : OrderM Unit := do
|
||||
let hvu ← mkProofForPath v u
|
||||
let u ← getExpr u
|
||||
let v ← getExpr v
|
||||
let h ← mkUnsatProof u v kuv huv kvu hvu
|
||||
closeGoal h
|
||||
|
||||
/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
|
||||
def setDist (u v : NodeId) (k : Weight) : OrderM Unit := do
|
||||
modifyStruct fun s => { s with
|
||||
targets := s.targets.modify u fun es => es.insert v k
|
||||
sources := s.sources.modify v fun es => es.insert u k
|
||||
}
|
||||
|
||||
def setProof (u v : NodeId) (p : ProofInfo) : OrderM Unit := do
|
||||
modifyStruct fun s => { s with
|
||||
proofs := s.proofs.modify u fun es => es.insert v p
|
||||
}
|
||||
|
||||
@[inline] def forEachSourceOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
|
||||
(← getStruct).sources[u]!.forM f
|
||||
|
||||
@[inline] def forEachTargetOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
|
||||
(← getStruct).targets[u]!.forM f
|
||||
|
||||
/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
|
||||
def isShorter (u v : NodeId) (k : Weight) : OrderM Bool := do
|
||||
if let some k' ← getDist? u v then
|
||||
return k < k'
|
||||
else
|
||||
return true
|
||||
|
||||
/-- Adds `p` to the list of things to be propagated. -/
|
||||
def pushToPropagate (p : ToPropagate) : OrderM Unit :=
|
||||
modifyStruct fun s => { s with propagate := p :: s.propagate }
|
||||
|
||||
/-
|
||||
def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do
|
||||
def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do
|
||||
let kuv ← mkProofForPath u v
|
||||
let u ← getExpr u
|
||||
let v ← getExpr v
|
||||
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k'
|
||||
let h ← mkPropagateEqTrueProof u v k kuv k'
|
||||
pushEqTrue e h
|
||||
|
||||
private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do
|
||||
def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do
|
||||
let kuv ← mkProofForPath u v
|
||||
let u ← getExpr u
|
||||
let v ← getExpr v
|
||||
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k'
|
||||
let h ← mkPropagateEqFalseProof u v k kuv k'
|
||||
pushEqFalse e h
|
||||
|
||||
/-- Propagates all pending constraints and equalities and resets to "to do" list. -/
|
||||
private def propagatePending : OrderM Unit := do
|
||||
let todo := (← getStruct).propagate
|
||||
modifyStruct fun s => { s with propagate := [] }
|
||||
for p in todo do
|
||||
match p with
|
||||
| .eqTrue e u v k k' => propagateEqTrue e u v k k'
|
||||
| .eqFalse e u v k k' => propagateEqFalse e u v k k'
|
||||
| .eq u v =>
|
||||
let ue ← getExpr u
|
||||
let ve ← getExpr v
|
||||
unless (← isEqv ue ve) do
|
||||
let huv ← mkProofForPath u v
|
||||
let hvu ← mkProofForPath v u
|
||||
let h ← mkEqProofOfLeOfLe ue ve huv hvu
|
||||
pushEq ue ve h
|
||||
|
||||
def Cnstr.getWeight? (c : Cnstr α) : Option Weight :=
|
||||
match c.kind with
|
||||
| .le => some { k := c.k }
|
||||
| .lt => some { k := c.k, strict := true }
|
||||
| .eq => none
|
||||
|
||||
/--
|
||||
Given `e` represented by constraint `c` (from `u` to `v`).
|
||||
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
|
||||
If it can, adds a new entry to propagation list.
|
||||
-/
|
||||
def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
|
||||
let some k' := c.getWeight? | return false
|
||||
if k ≤ k' then
|
||||
pushToPropagate <| .eqTrue e u v k k'
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
/--
|
||||
Given `e` represented by constraint `c` (from `v` to `u`).
|
||||
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
|
||||
If it can, adds a new entry to propagation list.
|
||||
-/
|
||||
def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
|
||||
let some k' := c.getWeight? | return false
|
||||
if (k + k').isNeg then
|
||||
pushToPropagate <| .eqFalse e u v k k'
|
||||
return true
|
||||
return false
|
||||
|
||||
/--
|
||||
Auxiliary function for implementing theory propagation.
|
||||
|
|
@ -51,8 +136,7 @@ associated with `(u, v)` IF
|
|||
- `e` is already assigned, or
|
||||
- `f c e` returns true
|
||||
-/
|
||||
@[inline]
|
||||
private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
|
||||
@[inline] def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
|
||||
if let some cs := (← getStruct).cnstrsOf.find? (u, v) then
|
||||
let cs' ← cs.filterM fun (c, e) => do
|
||||
if (← isEqTrue e <||> isEqFalse e) then
|
||||
|
|
@ -61,6 +145,62 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM
|
|||
return !(← f c e)
|
||||
modifyStruct fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
|
||||
|
||||
/-- Equality propagation. -/
|
||||
def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
|
||||
if !k.isZero then return ()
|
||||
let some k' ← getDist? v u | return ()
|
||||
if !k'.isZero then return ()
|
||||
let ue ← getExpr u
|
||||
let ve ← getExpr v
|
||||
if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
|
||||
if (← isEqv ue ve) then return ()
|
||||
pushToPropagate <| .eq u v
|
||||
|
||||
/-- Finds constrains and equalities to be propagated. -/
|
||||
def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do
|
||||
updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e)
|
||||
updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e)
|
||||
checkEq u v k
|
||||
|
||||
/--
|
||||
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
|
||||
`w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some`
|
||||
-/
|
||||
def updateIfShorter (u v : NodeId) (k : Weight) (w : NodeId) : OrderM Unit := do
|
||||
if (← isShorter u v k) then
|
||||
setDist u v k
|
||||
setProof u v (← getProof w v)
|
||||
checkToPropagate u v k
|
||||
|
||||
/--
|
||||
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
|
||||
if no negative cycle was created, updates the shortest distance of affected
|
||||
node pairs.
|
||||
-/
|
||||
def addEdge (u : NodeId) (v : NodeId) (k : Weight) (h : Expr) : OrderM Unit := do
|
||||
if (← isInconsistent) then return ()
|
||||
if let some k' ← getDist? v u then
|
||||
if (k + k').isNeg then
|
||||
setUnsat u v k h k'
|
||||
return ()
|
||||
if (← isShorter u v k) then
|
||||
setDist u v k
|
||||
setProof u v { w := u, k, proof := h }
|
||||
checkToPropagate u v k
|
||||
update
|
||||
propagatePending
|
||||
where
|
||||
update : OrderM Unit := do
|
||||
forEachTargetOf v fun j k₂ => do
|
||||
/- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/
|
||||
updateIfShorter u j (k+k₂) v
|
||||
forEachSourceOf u fun i k₁ => do
|
||||
/- Check whether new path: `i -(k₁)-> u -(k)-> v` is shorter -/
|
||||
updateIfShorter i v (k₁+k) u
|
||||
forEachTargetOf v fun j k₂ => do
|
||||
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
|
||||
updateIfShorter i j (k₁+k+k₂) v
|
||||
|
||||
def assertTrue (c : Cnstr NodeId) (p : Expr) : OrderM Unit := do
|
||||
trace[grind.order.assert] "{p} = True: {← c.pp}"
|
||||
|
||||
|
|
|
|||
|
|
@ -57,4 +57,7 @@ def getCnstr? (e : Expr) : OrderM (Option (Cnstr NodeId)) :=
|
|||
def isRing : OrderM Bool :=
|
||||
return (← getStruct).ringId?.isSome
|
||||
|
||||
def isPartialOrder : OrderM Bool :=
|
||||
return (← getStruct).isPartialInst?.isSome
|
||||
|
||||
end Lean.Meta.Grind.Order
|
||||
|
|
|
|||
|
|
@ -16,6 +16,13 @@ def mkLePreorderPrefix (declName : Name) : OrderM Expr := do
|
|||
let s ← getStruct
|
||||
return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPreorderInst
|
||||
|
||||
/--
|
||||
Returns `declName α leInst isPartialInst`
|
||||
-/
|
||||
def mkLePartialPrefix (declName : Name) : OrderM Expr := do
|
||||
let s ← getStruct
|
||||
return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPartialInst?.get!
|
||||
|
||||
/--
|
||||
Returns `declName α leInst ltInst lawfulOrderLtInst`
|
||||
-/
|
||||
|
|
@ -127,8 +134,13 @@ public def mkPropagateEqTrueProof (u v : Expr) (k : Weight) (huv : Expr) (k' : W
|
|||
/--
|
||||
`u < v → (v ≤ u) = False
|
||||
-/
|
||||
def mkPropagateEqFalseProofCore (u v : Expr) (huv : Expr) : OrderM Expr := do
|
||||
let h ← mkLeLtPreorderPrefix ``Grind.Order.le_eq_false_of_lt
|
||||
def mkPropagateEqFalseProofCore (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do
|
||||
let declName := match k'.strict, k.strict with
|
||||
| false, false => unreachable!
|
||||
| false, true => ``Grind.Order.le_eq_false_of_lt
|
||||
| true, false => ``Grind.Order.lt_eq_false_of_le
|
||||
| true, true => ``Grind.Order.lt_eq_false_of_lt
|
||||
let h ← mkLeLtPreorderPrefix declName
|
||||
return mkApp3 h u v huv
|
||||
|
||||
def mkPropagateEqFalseProofOffset (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do
|
||||
|
|
@ -148,7 +160,7 @@ public def mkPropagateEqFalseProof (u v : Expr) (k : Weight) (huv : Expr) (k' :
|
|||
if (← isRing) then
|
||||
mkPropagateEqFalseProofOffset u v k huv k'
|
||||
else
|
||||
mkPropagateEqFalseProofCore u v huv
|
||||
mkPropagateEqFalseProofCore u v k huv k'
|
||||
|
||||
def mkUnsatProofCore (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
|
||||
let h ← mkTransCoreProof u v u k₁.strict k₂.strict h₁ h₂
|
||||
|
|
@ -170,10 +182,14 @@ Returns a proof of `False` using a negative cycle composed of
|
|||
- `u --(k₁)--> v` with proof `h₁`
|
||||
- `v --(k₂)--> u` with proof `h₂`
|
||||
-/
|
||||
def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
|
||||
public def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
|
||||
if (← isRing) then
|
||||
mkUnsatProofOffset u v k₁ h₁ k₂ h₂
|
||||
else
|
||||
mkUnsatProofCore u v k₁ h₁ k₂ h₂
|
||||
|
||||
public def mkEqProofOfLeOfLe (u v : Expr) (h₁ : Expr) (h₂ : Expr) : OrderM Expr := do
|
||||
let h ← mkLePartialPrefix ``Grind.Order.eq_of_le_of_le
|
||||
return mkApp4 h u v h₁ h₂
|
||||
|
||||
end Lean.Meta.Grind.Order
|
||||
|
|
|
|||
|
|
@ -39,33 +39,6 @@ structure Weight where
|
|||
strict := false
|
||||
deriving Inhabited
|
||||
|
||||
def Weight.compare (a b : Weight) : Ordering :=
|
||||
if a.k < b.k then
|
||||
.lt
|
||||
else if b.k > a.k then
|
||||
.gt
|
||||
else if a.strict == b.strict then
|
||||
.eq
|
||||
else if !a.strict && b.strict then
|
||||
.lt
|
||||
else
|
||||
.gt
|
||||
|
||||
instance : Ord Weight where
|
||||
compare := Weight.compare
|
||||
|
||||
instance : LE Weight where
|
||||
le a b := compare a b ≠ .gt
|
||||
|
||||
instance : LT Weight where
|
||||
lt a b := compare a b = .lt
|
||||
|
||||
instance : DecidableLE Weight :=
|
||||
fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt))
|
||||
|
||||
instance : DecidableLT Weight :=
|
||||
fun a b => inferInstanceAs (Decidable (compare a b = .lt))
|
||||
|
||||
/-- Auxiliary structure used for proof extraction. -/
|
||||
structure ProofInfo where
|
||||
w : NodeId
|
||||
|
|
@ -82,8 +55,8 @@ Thus, we store the information to be propagated into a list.
|
|||
See field `propagate` in `State`.
|
||||
-/
|
||||
inductive ToPropagate where
|
||||
| eqTrue (e : Expr) (u v : NodeId) (k k' : Int)
|
||||
| eqFalse (e : Expr) (u v : NodeId) (k k' : Int)
|
||||
| eqTrue (e : Expr) (u v : NodeId) (k k' : Weight)
|
||||
| eqFalse (e : Expr) (u v : NodeId) (k k' : Weight)
|
||||
| eq (u v : NodeId)
|
||||
deriving Inhabited
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ module
|
|||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Order.OrderM
|
||||
import Lean.Meta.Tactic.Grind.Arith.Util
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Order
|
||||
|
||||
public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
|
||||
def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
|
||||
let u ← getExpr c.u
|
||||
let v ← getExpr c.v
|
||||
let op := match c.kind with
|
||||
|
|
@ -21,4 +22,49 @@ public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
|
|||
else
|
||||
return m!"{Arith.quoteIfArithTerm u} {op} {Arith.quoteIfArithTerm v}"
|
||||
|
||||
def Weight.compare (a b : Weight) : Ordering :=
|
||||
if a.k < b.k then
|
||||
.lt
|
||||
else if b.k > a.k then
|
||||
.gt
|
||||
else if a.strict == b.strict then
|
||||
.eq
|
||||
else if a.strict && !b.strict then
|
||||
/-
|
||||
**Note**: Recall that we view a constraint of the
|
||||
form `x < y + k` as `x ≤ y + (k - ε)` where `ε` is
|
||||
an "infinitesimal" positive value.
|
||||
Thus, `k - ε < k`
|
||||
-/
|
||||
.lt
|
||||
else
|
||||
.gt
|
||||
|
||||
instance : Ord Weight where
|
||||
compare := Weight.compare
|
||||
|
||||
instance : LE Weight where
|
||||
le a b := compare a b ≠ .gt
|
||||
|
||||
instance : LT Weight where
|
||||
lt a b := compare a b = .lt
|
||||
|
||||
instance : DecidableLE Weight :=
|
||||
fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt))
|
||||
|
||||
instance : DecidableLT Weight :=
|
||||
fun a b => inferInstanceAs (Decidable (compare a b = .lt))
|
||||
|
||||
def Weight.add (a b : Weight) : Weight :=
|
||||
{ k := a.k + b.k, strict := a.strict || b.strict }
|
||||
|
||||
instance : Add Weight where
|
||||
add := Weight.add
|
||||
|
||||
def Weight.isNeg (a : Weight) : Bool :=
|
||||
a.k < 0 || (a.k == 0 && a.strict)
|
||||
|
||||
def Weight.isZero (a : Weight) : Bool :=
|
||||
a.k == 0 && !a.strict
|
||||
|
||||
end Lean.Meta.Grind.Order
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue