From 0504e32bb75524b0438e35c339af26e050178ff1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 27 Sep 2025 11:18:41 -0700 Subject: [PATCH] feat: add `addEdge` to `grind order` (#10596) This PR implements the function for adding new edges to the graph used by `grind order`. The graph maintains the transitive closure of all asserted constraints. --- src/Init/Grind/Order.lean | 14 ++ src/Lean/Meta/Tactic/Grind/Order/Assert.lean | 154 ++++++++++++++++++- src/Lean/Meta/Tactic/Grind/Order/OrderM.lean | 3 + src/Lean/Meta/Tactic/Grind/Order/Proof.lean | 24 ++- src/Lean/Meta/Tactic/Grind/Order/Types.lean | 31 +--- src/Lean/Meta/Tactic/Grind/Order/Util.lean | 48 +++++- 6 files changed, 233 insertions(+), 41 deletions(-) diff --git a/src/Init/Grind/Order.lean b/src/Init/Grind/Order.lean index 917d1854ae..e362b674f1 100644 --- a/src/Init/Grind/Order.lean +++ b/src/Init/Grind/Order.lean @@ -196,6 +196,20 @@ theorem le_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPre have := Preorder.lt_irrefl a contradiction +theorem lt_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] + (a b : α) : a < b → (b < a) = False := by + simp; intro h₁ h₂ + have := lt_trans h₁ h₂ + have := Preorder.lt_irrefl a + contradiction + +theorem lt_eq_false_of_le {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] + (a b : α) : a ≤ b → (b < a) = False := by + simp; intro h₁ h₂ + have := le_lt_trans h₁ h₂ + have := Preorder.lt_irrefl a + contradiction + theorem le_eq_false_of_le_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] [Ring α] [OrderedRing α] (a b : α) (k₁ k₂ : Int) : (k₂ + k₁).blt' 0 → a ≤ b + k₁ → (b ≤ a + k₂) = False := by intro h₁; simp; intro h₂ h₃ diff --git a/src/Lean/Meta/Tactic/Grind/Order/Assert.lean b/src/Lean/Meta/Tactic/Grind/Order/Assert.lean index 9e5f30cffc..8ac3884f06 100644 --- a/src/Lean/Meta/Tactic/Grind/Order/Assert.lean +++ b/src/Lean/Meta/Tactic/Grind/Order/Assert.lean @@ -25,23 +25,108 @@ where let p' ← getProof u p.w go (← mkTrans p' p v) +/-- +Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t. +it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`, +this function closes the current goal by constructing a proof of `False`. +-/ +def setUnsat (u v : NodeId) (kuv : Weight) (huv : Expr) (kvu : Weight) : OrderM Unit := do + let hvu ← mkProofForPath v u + let u ← getExpr u + let v ← getExpr v + let h ← mkUnsatProof u v kuv huv kvu hvu + closeGoal h + +/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/ +def setDist (u v : NodeId) (k : Weight) : OrderM Unit := do + modifyStruct fun s => { s with + targets := s.targets.modify u fun es => es.insert v k + sources := s.sources.modify v fun es => es.insert u k + } + +def setProof (u v : NodeId) (p : ProofInfo) : OrderM Unit := do + modifyStruct fun s => { s with + proofs := s.proofs.modify u fun es => es.insert v p + } + +@[inline] def forEachSourceOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do + (← getStruct).sources[u]!.forM f + +@[inline] def forEachTargetOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do + (← getStruct).targets[u]!.forM f + +/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/ +def isShorter (u v : NodeId) (k : Weight) : OrderM Bool := do + if let some k' ← getDist? u v then + return k < k' + else + return true + /-- Adds `p` to the list of things to be propagated. -/ def pushToPropagate (p : ToPropagate) : OrderM Unit := modifyStruct fun s => { s with propagate := p :: s.propagate } -/- -def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do +def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do let kuv ← mkProofForPath u v let u ← getExpr u let v ← getExpr v - pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k' + let h ← mkPropagateEqTrueProof u v k kuv k' + pushEqTrue e h -private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do +def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do let kuv ← mkProofForPath u v let u ← getExpr u let v ← getExpr v - pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k' + let h ← mkPropagateEqFalseProof u v k kuv k' + pushEqFalse e h + +/-- Propagates all pending constraints and equalities and resets to "to do" list. -/ +private def propagatePending : OrderM Unit := do + let todo := (← getStruct).propagate + modifyStruct fun s => { s with propagate := [] } + for p in todo do + match p with + | .eqTrue e u v k k' => propagateEqTrue e u v k k' + | .eqFalse e u v k k' => propagateEqFalse e u v k k' + | .eq u v => + let ue ← getExpr u + let ve ← getExpr v + unless (← isEqv ue ve) do + let huv ← mkProofForPath u v + let hvu ← mkProofForPath v u + let h ← mkEqProofOfLeOfLe ue ve huv hvu + pushEq ue ve h + +def Cnstr.getWeight? (c : Cnstr α) : Option Weight := + match c.kind with + | .le => some { k := c.k } + | .lt => some { k := c.k, strict := true } + | .eq => none + +/-- +Given `e` represented by constraint `c` (from `u` to `v`). +Checks whether `e = True` can be propagated using the path `u --(k)--> v`. +If it can, adds a new entry to propagation list. -/ +def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do + let some k' := c.getWeight? | return false + if k ≤ k' then + pushToPropagate <| .eqTrue e u v k k' + return true + else + return false + +/-- +Given `e` represented by constraint `c` (from `v` to `u`). +Checks whether `e = False` can be propagated using the path `u --(k)--> v`. +If it can, adds a new entry to propagation list. +-/ +def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do + let some k' := c.getWeight? | return false + if (k + k').isNeg then + pushToPropagate <| .eqFalse e u v k k' + return true + return false /-- Auxiliary function for implementing theory propagation. @@ -51,8 +136,7 @@ associated with `(u, v)` IF - `e` is already assigned, or - `f c e` returns true -/ -@[inline] -private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do +@[inline] def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do if let some cs := (← getStruct).cnstrsOf.find? (u, v) then let cs' ← cs.filterM fun (c, e) => do if (← isEqTrue e <||> isEqFalse e) then @@ -61,6 +145,62 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM return !(← f c e) modifyStruct fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' } +/-- Equality propagation. -/ +def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do + if !k.isZero then return () + let some k' ← getDist? v u | return () + if !k'.isZero then return () + let ue ← getExpr u + let ve ← getExpr v + if (← alreadyInternalized ue <&&> alreadyInternalized ve) then + if (← isEqv ue ve) then return () + pushToPropagate <| .eq u v + +/-- Finds constrains and equalities to be propagated. -/ +def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do + updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e) + updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e) + checkEq u v k + +/-- +If `isShorter u v k`, updates the shortest distance between `u` and `v`. +`w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some` +-/ +def updateIfShorter (u v : NodeId) (k : Weight) (w : NodeId) : OrderM Unit := do + if (← isShorter u v k) then + setDist u v k + setProof u v (← getProof w v) + checkToPropagate u v k + +/-- +Adds an edge `u --(k) --> v` justified by the proof term `p`, and then +if no negative cycle was created, updates the shortest distance of affected +node pairs. +-/ +def addEdge (u : NodeId) (v : NodeId) (k : Weight) (h : Expr) : OrderM Unit := do + if (← isInconsistent) then return () + if let some k' ← getDist? v u then + if (k + k').isNeg then + setUnsat u v k h k' + return () + if (← isShorter u v k) then + setDist u v k + setProof u v { w := u, k, proof := h } + checkToPropagate u v k + update + propagatePending +where + update : OrderM Unit := do + forEachTargetOf v fun j k₂ => do + /- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/ + updateIfShorter u j (k+k₂) v + forEachSourceOf u fun i k₁ => do + /- Check whether new path: `i -(k₁)-> u -(k)-> v` is shorter -/ + updateIfShorter i v (k₁+k) u + forEachTargetOf v fun j k₂ => do + /- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/ + updateIfShorter i j (k₁+k+k₂) v + def assertTrue (c : Cnstr NodeId) (p : Expr) : OrderM Unit := do trace[grind.order.assert] "{p} = True: {← c.pp}" diff --git a/src/Lean/Meta/Tactic/Grind/Order/OrderM.lean b/src/Lean/Meta/Tactic/Grind/Order/OrderM.lean index 0e92e22c3c..eac5486756 100644 --- a/src/Lean/Meta/Tactic/Grind/Order/OrderM.lean +++ b/src/Lean/Meta/Tactic/Grind/Order/OrderM.lean @@ -57,4 +57,7 @@ def getCnstr? (e : Expr) : OrderM (Option (Cnstr NodeId)) := def isRing : OrderM Bool := return (← getStruct).ringId?.isSome +def isPartialOrder : OrderM Bool := + return (← getStruct).isPartialInst?.isSome + end Lean.Meta.Grind.Order diff --git a/src/Lean/Meta/Tactic/Grind/Order/Proof.lean b/src/Lean/Meta/Tactic/Grind/Order/Proof.lean index bf33c6d25e..b5d0dbce75 100644 --- a/src/Lean/Meta/Tactic/Grind/Order/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Order/Proof.lean @@ -16,6 +16,13 @@ def mkLePreorderPrefix (declName : Name) : OrderM Expr := do let s ← getStruct return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPreorderInst +/-- +Returns `declName α leInst isPartialInst` +-/ +def mkLePartialPrefix (declName : Name) : OrderM Expr := do + let s ← getStruct + return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPartialInst?.get! + /-- Returns `declName α leInst ltInst lawfulOrderLtInst` -/ @@ -127,8 +134,13 @@ public def mkPropagateEqTrueProof (u v : Expr) (k : Weight) (huv : Expr) (k' : W /-- `u < v → (v ≤ u) = False -/ -def mkPropagateEqFalseProofCore (u v : Expr) (huv : Expr) : OrderM Expr := do - let h ← mkLeLtPreorderPrefix ``Grind.Order.le_eq_false_of_lt +def mkPropagateEqFalseProofCore (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do + let declName := match k'.strict, k.strict with + | false, false => unreachable! + | false, true => ``Grind.Order.le_eq_false_of_lt + | true, false => ``Grind.Order.lt_eq_false_of_le + | true, true => ``Grind.Order.lt_eq_false_of_lt + let h ← mkLeLtPreorderPrefix declName return mkApp3 h u v huv def mkPropagateEqFalseProofOffset (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do @@ -148,7 +160,7 @@ public def mkPropagateEqFalseProof (u v : Expr) (k : Weight) (huv : Expr) (k' : if (← isRing) then mkPropagateEqFalseProofOffset u v k huv k' else - mkPropagateEqFalseProofCore u v huv + mkPropagateEqFalseProofCore u v k huv k' def mkUnsatProofCore (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do let h ← mkTransCoreProof u v u k₁.strict k₂.strict h₁ h₂ @@ -170,10 +182,14 @@ Returns a proof of `False` using a negative cycle composed of - `u --(k₁)--> v` with proof `h₁` - `v --(k₂)--> u` with proof `h₂` -/ -def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do +public def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do if (← isRing) then mkUnsatProofOffset u v k₁ h₁ k₂ h₂ else mkUnsatProofCore u v k₁ h₁ k₂ h₂ +public def mkEqProofOfLeOfLe (u v : Expr) (h₁ : Expr) (h₂ : Expr) : OrderM Expr := do + let h ← mkLePartialPrefix ``Grind.Order.eq_of_le_of_le + return mkApp4 h u v h₁ h₂ + end Lean.Meta.Grind.Order diff --git a/src/Lean/Meta/Tactic/Grind/Order/Types.lean b/src/Lean/Meta/Tactic/Grind/Order/Types.lean index 61420b79fb..33c47e76a2 100644 --- a/src/Lean/Meta/Tactic/Grind/Order/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Order/Types.lean @@ -39,33 +39,6 @@ structure Weight where strict := false deriving Inhabited -def Weight.compare (a b : Weight) : Ordering := - if a.k < b.k then - .lt - else if b.k > a.k then - .gt - else if a.strict == b.strict then - .eq - else if !a.strict && b.strict then - .lt - else - .gt - -instance : Ord Weight where - compare := Weight.compare - -instance : LE Weight where - le a b := compare a b ≠ .gt - -instance : LT Weight where - lt a b := compare a b = .lt - -instance : DecidableLE Weight := - fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt)) - -instance : DecidableLT Weight := - fun a b => inferInstanceAs (Decidable (compare a b = .lt)) - /-- Auxiliary structure used for proof extraction. -/ structure ProofInfo where w : NodeId @@ -82,8 +55,8 @@ Thus, we store the information to be propagated into a list. See field `propagate` in `State`. -/ inductive ToPropagate where - | eqTrue (e : Expr) (u v : NodeId) (k k' : Int) - | eqFalse (e : Expr) (u v : NodeId) (k k' : Int) + | eqTrue (e : Expr) (u v : NodeId) (k k' : Weight) + | eqFalse (e : Expr) (u v : NodeId) (k k' : Weight) | eq (u v : NodeId) deriving Inhabited diff --git a/src/Lean/Meta/Tactic/Grind/Order/Util.lean b/src/Lean/Meta/Tactic/Grind/Order/Util.lean index 37fe6a2311..08b90a5ed7 100644 --- a/src/Lean/Meta/Tactic/Grind/Order/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/Order/Util.lean @@ -7,9 +7,10 @@ module prelude public import Lean.Meta.Tactic.Grind.Order.OrderM import Lean.Meta.Tactic.Grind.Arith.Util +public section namespace Lean.Meta.Grind.Order -public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do +def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do let u ← getExpr c.u let v ← getExpr c.v let op := match c.kind with @@ -21,4 +22,49 @@ public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do else return m!"{Arith.quoteIfArithTerm u} {op} {Arith.quoteIfArithTerm v}" +def Weight.compare (a b : Weight) : Ordering := + if a.k < b.k then + .lt + else if b.k > a.k then + .gt + else if a.strict == b.strict then + .eq + else if a.strict && !b.strict then + /- + **Note**: Recall that we view a constraint of the + form `x < y + k` as `x ≤ y + (k - ε)` where `ε` is + an "infinitesimal" positive value. + Thus, `k - ε < k` + -/ + .lt + else + .gt + +instance : Ord Weight where + compare := Weight.compare + +instance : LE Weight where + le a b := compare a b ≠ .gt + +instance : LT Weight where + lt a b := compare a b = .lt + +instance : DecidableLE Weight := + fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt)) + +instance : DecidableLT Weight := + fun a b => inferInstanceAs (Decidable (compare a b = .lt)) + +def Weight.add (a b : Weight) : Weight := + { k := a.k + b.k, strict := a.strict || b.strict } + +instance : Add Weight where + add := Weight.add + +def Weight.isNeg (a : Weight) : Bool := + a.k < 0 || (a.k == 0 && a.strict) + +def Weight.isZero (a : Weight) : Bool := + a.k == 0 && !a.strict + end Lean.Meta.Grind.Order