From 56f3ca6fc74ad5bfbc0feafc4d5a48e87571db62 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 21 Oct 2025 09:38:39 -0700 Subject: [PATCH] fix: propagation in `grind order` (#10877) This PR fixes theory propagation issue in `grind order`. --- src/Lean/Meta/Tactic/Grind/Order/Assert.lean | 28 +++- tests/lean/run/grind_indexmap_trace.lean | 13 +- tests/lean/run/grind_order_issue.lean | 128 +++++++++++++++++++ 3 files changed, 158 insertions(+), 11 deletions(-) create mode 100644 tests/lean/run/grind_order_issue.lean diff --git a/src/Lean/Meta/Tactic/Grind/Order/Assert.lean b/src/Lean/Meta/Tactic/Grind/Order/Assert.lean index 6bc77bf1d0..f15052c2bd 100644 --- a/src/Lean/Meta/Tactic/Grind/Order/Assert.lean +++ b/src/Lean/Meta/Tactic/Grind/Order/Assert.lean @@ -111,13 +111,23 @@ def propagatePending : OrderM Unit := do let h ← mkEqProofOfLeOfLe ue ve huv hvu pushEq ue ve h +/-- +Returns `true` if `e` is already `True` in the `grind` core. +Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`). +-/ +private def isAlreadyTrue (e : Expr) : OrderM Bool := do + if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then + alreadyInternalized e' <&&> isEqTrue e' + else + alreadyInternalized e <&&> isEqTrue e + /-- Given `e` represented by constraint `c` (from `u` to `v`). Checks whether `e = True` can be propagated using the path `u --(k)--> v`. If it can, adds a new entry to propagation list. -/ def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do - if (← alreadyInternalized e <&&> isEqTrue e) then return true + if (← isAlreadyTrue e) then return true let k' := c.getWeight trace[grind.debug.order.check_eq_true] "{← getExpr u}, {← getExpr v}, {k}, {k'}, {← c.pp}" if k ≤ k' then @@ -126,13 +136,23 @@ def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : Orde else return false +/-- +Returns `true` if `e` is already `False` in the `grind` core. +Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`). +-/ +private def isAlreadyFalse (e : Expr) : OrderM Bool := do + if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then + alreadyInternalized e' <&&> isEqFalse e' + else + alreadyInternalized e <&&> isEqFalse e + /-- Given `e` represented by constraint `c` (from `v` to `u`). Checks whether `e = False` can be propagated using the path `u --(k)--> v`. If it can, adds a new entry to propagation list. -/ def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do - if (← alreadyInternalized e <&&> isEqFalse e) then return true + if (← isAlreadyFalse e) then return true let k' := c.getWeight trace[grind.debug.order.check_eq_false] "{← getExpr u}, {← getExpr v}, {k}, {k'} {← c.pp}" if (k + k').isNeg then @@ -168,8 +188,8 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do /-- Finds constrains and equalities to be propagated. -/ def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do - updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e) - updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e) + updateCnstrsOf u v fun c e => checkEqTrue u v k c e + updateCnstrsOf v u fun c e => checkEqFalse u v k c e checkEq u v k /-- diff --git a/tests/lean/run/grind_indexmap_trace.lean b/tests/lean/run/grind_indexmap_trace.lean index 74ad0ca261..450abcceac 100644 --- a/tests/lean/run/grind_indexmap_trace.lean +++ b/tests/lean/run/grind_indexmap_trace.lean @@ -239,7 +239,7 @@ example (m : IndexMap α β) (a a' : α) (b : β) : /-- info: Try this: [apply] ⏎ - instantiate only [= getElem_def, insert] + instantiate approx [= getElem_def, = mem_indices_of_mem, insert] instantiate only [= getElem?_neg, = getElem?_pos] cases #f590 next => @@ -249,8 +249,7 @@ info: Try this: instantiate only [= Array.getElem_set] next => instantiate only - instantiate approx [= HashMap.getElem_insert, = Array.size_push, size, = Array.getElem_push, - = HashMap.contains_insert, = HashMap.mem_insert, = Array.size_push] + instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert, = HashMap.mem_insert] next => instantiate only [= getElem_def, = mem_indices_of_mem] instantiate only [usr getElem_indices_lt] @@ -272,7 +271,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) : example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) : (m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by grind => - instantiate only [= getElem_def, insert] + -- **TODO**: Check approx here + instantiate approx [= getElem_def, = mem_indices_of_mem, insert] instantiate only [= getElem?_neg, = getElem?_pos] cases #f590 next => @@ -282,9 +282,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) : instantiate only [= Array.getElem_set] next => instantiate only - -- **TODO**: Investigate why we need `approx` here - instantiate approx [= HashMap.getElem_insert, = Array.size_push, size, = Array.getElem_push, - = HashMap.contains_insert, = HashMap.mem_insert, = Array.size_push] + instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert, + = HashMap.mem_insert] next => instantiate only [= getElem_def, = mem_indices_of_mem] instantiate only [usr getElem_indices_lt] diff --git a/tests/lean/run/grind_order_issue.lean b/tests/lean/run/grind_order_issue.lean new file mode 100644 index 0000000000..de17435aff --- /dev/null +++ b/tests/lean/run/grind_order_issue.lean @@ -0,0 +1,128 @@ +import Std.Data.HashMap +set_option warn.sorry false +macro_rules | `(tactic| get_elem_tactic_extensible) => `(tactic| grind) + +open Std + +structure IndexMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where + private indices : HashMap α Nat + private keys : Array α + private values : Array β + private size_keys' : keys.size = values.size := by grind + private WF : ∀ (i : Nat) (a : α), keys[i]? = some a ↔ indices[a]? = some i := by grind + +namespace IndexMap + +variable {α : Type u} {β : Type v} [BEq α] [Hashable α] +variable {m : IndexMap α β} {a : α} {b : β} {i : Nat} + +@[inline] def size (m : IndexMap α β) : Nat := + m.values.size + +@[local grind =] private theorem size_keys : m.keys.size = m.size := m.size_keys' + +def emptyWithCapacity (capacity := 8) : IndexMap α β where + indices := HashMap.emptyWithCapacity capacity + keys := Array.emptyWithCapacity capacity + values := Array.emptyWithCapacity capacity + +instance : EmptyCollection (IndexMap α β) where + emptyCollection := emptyWithCapacity + +instance : Inhabited (IndexMap α β) where + default := ∅ + +@[inline] def contains (m : IndexMap α β) + (a : α) : Bool := + m.indices.contains a + +instance : Membership α (IndexMap α β) where + mem m a := a ∈ m.indices + +instance {m : IndexMap α β} {a : α} : Decidable (a ∈ m) := + inferInstanceAs (Decidable (a ∈ m.indices)) + +@[local grind =] private theorem mem_indices_of_mem {m : IndexMap α β} {a : α} : + a ∈ m ↔ a ∈ m.indices := Iff.rfl + +@[inline] def findIdx? (m : IndexMap α β) (a : α) : Option Nat := m.indices[a]? + +@[inline] def findIdx (m : IndexMap α β) (a : α) (h : a ∈ m := by get_elem_tactic) : Nat := m.indices[a] + +@[inline] def getIdx? (m : IndexMap α β) (i : Nat) : Option β := m.values[i]? + +@[inline] def getIdx (m : IndexMap α β) (i : Nat) (h : i < m.size := by get_elem_tactic) : β := + m.values[i] + +variable [LawfulBEq α] [LawfulHashable α] + +attribute [local grind _=_] IndexMap.WF + +private theorem getElem_indices_lt {h : a ∈ m} : m.indices[a] < m.size := by + have : m.indices[a]? = some m.indices[a] := by grind + grind + +grind_pattern getElem_indices_lt => m.indices[a] + +attribute [local grind] size + +instance : GetElem? (IndexMap α β) α β (fun m a => a ∈ m) where + getElem m a h := m.values[m.indices[a]'h] + getElem? m a := m.indices[a]?.bind (fun i => (m.values[i]?)) + getElem! m a := m.indices[a]?.bind (fun i => (m.values[i]?)) |>.getD default + +@[local grind =] private theorem getElem_def (m : IndexMap α β) (a : α) (h : a ∈ m) : m[a] = m.values[m.indices[a]'h] := rfl +@[local grind =] private theorem getElem?_def (m : IndexMap α β) (a : α) : + m[a]? = m.indices[a]?.bind (fun i => (m.values[i]?)) := rfl +@[local grind =] private theorem getElem!_def [Inhabited β] (m : IndexMap α β) (a : α) : + m[a]! = (m.indices[a]?.bind (fun i => (m.values[i]?))).getD default := rfl + +instance : LawfulGetElem (IndexMap α β) α β (fun m a => a ∈ m) where + getElem?_def := by grind + getElem!_def := by grind + +@[inline] def insert [LawfulBEq α] (m : IndexMap α β) (a : α) (b : β) : + IndexMap α β := + match h : m.indices[a]? with + | some i => + { indices := m.indices + keys := m.keys.set i a + values := m.values.set i b } + | none => + { indices := m.indices.insert a m.size + keys := m.keys.push a + values := m.values.push b } + +/-! ### Verification theorems -/ + +attribute [local grind] getIdx findIdx insert + +example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) : + (m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by + grind -offset -ring -linarith -cutsat => + instantiate only [= getElem_def, insert] + cases #f590 + next => + cases #ffdf + next => sorry + next => + instantiate only + instantiate only [= HashMap.getElem_insert] + instantiate only [= size] + instantiate only [= Array.getElem_push, = mem_indices_of_mem] + next => sorry + +example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) : + (m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by + grind -offset -ring -linarith -cutsat => + instantiate only [= getElem_def, insert] + cases #f590 + next => + cases #ffdf + next => sorry + next => + instantiate only + instantiate only [= HashMap.getElem_insert] + instantiate only [= size] + instantiate only [= mem_indices_of_mem, = Array.getElem_push] + next => sorry