fix: propagation in grind order (#10877)
This PR fixes theory propagation issue in `grind order`.
This commit is contained in:
parent
94cb32bc46
commit
56f3ca6fc7
3 changed files with 158 additions and 11 deletions
|
|
@ -111,13 +111,23 @@ def propagatePending : OrderM Unit := do
|
|||
let h ← mkEqProofOfLeOfLe ue ve huv hvu
|
||||
pushEq ue ve h
|
||||
|
||||
/--
|
||||
Returns `true` if `e` is already `True` in the `grind` core.
|
||||
Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`).
|
||||
-/
|
||||
private def isAlreadyTrue (e : Expr) : OrderM Bool := do
|
||||
if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then
|
||||
alreadyInternalized e' <&&> isEqTrue e'
|
||||
else
|
||||
alreadyInternalized e <&&> isEqTrue e
|
||||
|
||||
/--
|
||||
Given `e` represented by constraint `c` (from `u` to `v`).
|
||||
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
|
||||
If it can, adds a new entry to propagation list.
|
||||
-/
|
||||
def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
|
||||
if (← alreadyInternalized e <&&> isEqTrue e) then return true
|
||||
if (← isAlreadyTrue e) then return true
|
||||
let k' := c.getWeight
|
||||
trace[grind.debug.order.check_eq_true] "{← getExpr u}, {← getExpr v}, {k}, {k'}, {← c.pp}"
|
||||
if k ≤ k' then
|
||||
|
|
@ -126,13 +136,23 @@ def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : Orde
|
|||
else
|
||||
return false
|
||||
|
||||
/--
|
||||
Returns `true` if `e` is already `False` in the `grind` core.
|
||||
Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`).
|
||||
-/
|
||||
private def isAlreadyFalse (e : Expr) : OrderM Bool := do
|
||||
if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then
|
||||
alreadyInternalized e' <&&> isEqFalse e'
|
||||
else
|
||||
alreadyInternalized e <&&> isEqFalse e
|
||||
|
||||
/--
|
||||
Given `e` represented by constraint `c` (from `v` to `u`).
|
||||
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
|
||||
If it can, adds a new entry to propagation list.
|
||||
-/
|
||||
def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
|
||||
if (← alreadyInternalized e <&&> isEqFalse e) then return true
|
||||
if (← isAlreadyFalse e) then return true
|
||||
let k' := c.getWeight
|
||||
trace[grind.debug.order.check_eq_false] "{← getExpr u}, {← getExpr v}, {k}, {k'} {← c.pp}"
|
||||
if (k + k').isNeg then
|
||||
|
|
@ -168,8 +188,8 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
|
|||
|
||||
/-- Finds constrains and equalities to be propagated. -/
|
||||
def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do
|
||||
updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e)
|
||||
updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e)
|
||||
updateCnstrsOf u v fun c e => checkEqTrue u v k c e
|
||||
updateCnstrsOf v u fun c e => checkEqFalse u v k c e
|
||||
checkEq u v k
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ example (m : IndexMap α β) (a a' : α) (b : β) :
|
|||
/--
|
||||
info: Try this:
|
||||
[apply] ⏎
|
||||
instantiate only [= getElem_def, insert]
|
||||
instantiate approx [= getElem_def, = mem_indices_of_mem, insert]
|
||||
instantiate only [= getElem?_neg, = getElem?_pos]
|
||||
cases #f590
|
||||
next =>
|
||||
|
|
@ -249,8 +249,7 @@ info: Try this:
|
|||
instantiate only [= Array.getElem_set]
|
||||
next =>
|
||||
instantiate only
|
||||
instantiate approx [= HashMap.getElem_insert, = Array.size_push, size, = Array.getElem_push,
|
||||
= HashMap.contains_insert, = HashMap.mem_insert, = Array.size_push]
|
||||
instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert, = HashMap.mem_insert]
|
||||
next =>
|
||||
instantiate only [= getElem_def, = mem_indices_of_mem]
|
||||
instantiate only [usr getElem_indices_lt]
|
||||
|
|
@ -272,7 +271,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
|
|||
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
|
||||
(m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by
|
||||
grind =>
|
||||
instantiate only [= getElem_def, insert]
|
||||
-- **TODO**: Check approx here
|
||||
instantiate approx [= getElem_def, = mem_indices_of_mem, insert]
|
||||
instantiate only [= getElem?_neg, = getElem?_pos]
|
||||
cases #f590
|
||||
next =>
|
||||
|
|
@ -282,9 +282,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
|
|||
instantiate only [= Array.getElem_set]
|
||||
next =>
|
||||
instantiate only
|
||||
-- **TODO**: Investigate why we need `approx` here
|
||||
instantiate approx [= HashMap.getElem_insert, = Array.size_push, size, = Array.getElem_push,
|
||||
= HashMap.contains_insert, = HashMap.mem_insert, = Array.size_push]
|
||||
instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert,
|
||||
= HashMap.mem_insert]
|
||||
next =>
|
||||
instantiate only [= getElem_def, = mem_indices_of_mem]
|
||||
instantiate only [usr getElem_indices_lt]
|
||||
|
|
|
|||
128
tests/lean/run/grind_order_issue.lean
Normal file
128
tests/lean/run/grind_order_issue.lean
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
import Std.Data.HashMap
|
||||
set_option warn.sorry false
|
||||
macro_rules | `(tactic| get_elem_tactic_extensible) => `(tactic| grind)
|
||||
|
||||
open Std
|
||||
|
||||
structure IndexMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
|
||||
private indices : HashMap α Nat
|
||||
private keys : Array α
|
||||
private values : Array β
|
||||
private size_keys' : keys.size = values.size := by grind
|
||||
private WF : ∀ (i : Nat) (a : α), keys[i]? = some a ↔ indices[a]? = some i := by grind
|
||||
|
||||
namespace IndexMap
|
||||
|
||||
variable {α : Type u} {β : Type v} [BEq α] [Hashable α]
|
||||
variable {m : IndexMap α β} {a : α} {b : β} {i : Nat}
|
||||
|
||||
@[inline] def size (m : IndexMap α β) : Nat :=
|
||||
m.values.size
|
||||
|
||||
@[local grind =] private theorem size_keys : m.keys.size = m.size := m.size_keys'
|
||||
|
||||
def emptyWithCapacity (capacity := 8) : IndexMap α β where
|
||||
indices := HashMap.emptyWithCapacity capacity
|
||||
keys := Array.emptyWithCapacity capacity
|
||||
values := Array.emptyWithCapacity capacity
|
||||
|
||||
instance : EmptyCollection (IndexMap α β) where
|
||||
emptyCollection := emptyWithCapacity
|
||||
|
||||
instance : Inhabited (IndexMap α β) where
|
||||
default := ∅
|
||||
|
||||
@[inline] def contains (m : IndexMap α β)
|
||||
(a : α) : Bool :=
|
||||
m.indices.contains a
|
||||
|
||||
instance : Membership α (IndexMap α β) where
|
||||
mem m a := a ∈ m.indices
|
||||
|
||||
instance {m : IndexMap α β} {a : α} : Decidable (a ∈ m) :=
|
||||
inferInstanceAs (Decidable (a ∈ m.indices))
|
||||
|
||||
@[local grind =] private theorem mem_indices_of_mem {m : IndexMap α β} {a : α} :
|
||||
a ∈ m ↔ a ∈ m.indices := Iff.rfl
|
||||
|
||||
@[inline] def findIdx? (m : IndexMap α β) (a : α) : Option Nat := m.indices[a]?
|
||||
|
||||
@[inline] def findIdx (m : IndexMap α β) (a : α) (h : a ∈ m := by get_elem_tactic) : Nat := m.indices[a]
|
||||
|
||||
@[inline] def getIdx? (m : IndexMap α β) (i : Nat) : Option β := m.values[i]?
|
||||
|
||||
@[inline] def getIdx (m : IndexMap α β) (i : Nat) (h : i < m.size := by get_elem_tactic) : β :=
|
||||
m.values[i]
|
||||
|
||||
variable [LawfulBEq α] [LawfulHashable α]
|
||||
|
||||
attribute [local grind _=_] IndexMap.WF
|
||||
|
||||
private theorem getElem_indices_lt {h : a ∈ m} : m.indices[a] < m.size := by
|
||||
have : m.indices[a]? = some m.indices[a] := by grind
|
||||
grind
|
||||
|
||||
grind_pattern getElem_indices_lt => m.indices[a]
|
||||
|
||||
attribute [local grind] size
|
||||
|
||||
instance : GetElem? (IndexMap α β) α β (fun m a => a ∈ m) where
|
||||
getElem m a h := m.values[m.indices[a]'h]
|
||||
getElem? m a := m.indices[a]?.bind (fun i => (m.values[i]?))
|
||||
getElem! m a := m.indices[a]?.bind (fun i => (m.values[i]?)) |>.getD default
|
||||
|
||||
@[local grind =] private theorem getElem_def (m : IndexMap α β) (a : α) (h : a ∈ m) : m[a] = m.values[m.indices[a]'h] := rfl
|
||||
@[local grind =] private theorem getElem?_def (m : IndexMap α β) (a : α) :
|
||||
m[a]? = m.indices[a]?.bind (fun i => (m.values[i]?)) := rfl
|
||||
@[local grind =] private theorem getElem!_def [Inhabited β] (m : IndexMap α β) (a : α) :
|
||||
m[a]! = (m.indices[a]?.bind (fun i => (m.values[i]?))).getD default := rfl
|
||||
|
||||
instance : LawfulGetElem (IndexMap α β) α β (fun m a => a ∈ m) where
|
||||
getElem?_def := by grind
|
||||
getElem!_def := by grind
|
||||
|
||||
@[inline] def insert [LawfulBEq α] (m : IndexMap α β) (a : α) (b : β) :
|
||||
IndexMap α β :=
|
||||
match h : m.indices[a]? with
|
||||
| some i =>
|
||||
{ indices := m.indices
|
||||
keys := m.keys.set i a
|
||||
values := m.values.set i b }
|
||||
| none =>
|
||||
{ indices := m.indices.insert a m.size
|
||||
keys := m.keys.push a
|
||||
values := m.values.push b }
|
||||
|
||||
/-! ### Verification theorems -/
|
||||
|
||||
attribute [local grind] getIdx findIdx insert
|
||||
|
||||
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
|
||||
(m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by
|
||||
grind -offset -ring -linarith -cutsat =>
|
||||
instantiate only [= getElem_def, insert]
|
||||
cases #f590
|
||||
next =>
|
||||
cases #ffdf
|
||||
next => sorry
|
||||
next =>
|
||||
instantiate only
|
||||
instantiate only [= HashMap.getElem_insert]
|
||||
instantiate only [= size]
|
||||
instantiate only [= Array.getElem_push, = mem_indices_of_mem]
|
||||
next => sorry
|
||||
|
||||
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
|
||||
(m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by
|
||||
grind -offset -ring -linarith -cutsat =>
|
||||
instantiate only [= getElem_def, insert]
|
||||
cases #f590
|
||||
next =>
|
||||
cases #ffdf
|
||||
next => sorry
|
||||
next =>
|
||||
instantiate only
|
||||
instantiate only [= HashMap.getElem_insert]
|
||||
instantiate only [= size]
|
||||
instantiate only [= mem_indices_of_mem, = Array.getElem_push]
|
||||
next => sorry
|
||||
Loading…
Add table
Reference in a new issue