feat: offset equalities in grind (#6645)
This PR implements support for offset equality constraints in the
`grind` tactic and exhaustive equality propagation for them. The `grind`
tactic can now solve problems such as the following:
```lean
example (f : Nat → Nat) (a b c d e : Nat) :
f (a + 3) = b →
f (c + 1) = d →
c ≤ a + 2 →
a + 1 ≤ e →
e < c →
b = d := by
grind
```
This commit is contained in:
parent
3da7f70014
commit
563d5e8bcf
13 changed files with 153 additions and 44 deletions
|
|
@ -80,4 +80,12 @@ theorem Nat.ro_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true →
|
|||
theorem Nat.lo_eq_false_of_ro (u v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ v + k₁ → (v + k₂ ≤ u) = False := by
|
||||
simp [isLt]; omega
|
||||
|
||||
/-!
|
||||
Helper theorems for equality propagation
|
||||
-/
|
||||
|
||||
theorem Nat.le_of_eq_1 (u v : Nat) : u = v → u ≤ v := by omega
|
||||
theorem Nat.le_of_eq_2 (u v : Nat) : u = v → v ≤ u := by omega
|
||||
theorem Nat.eq_of_le_of_le (u v : Nat) : u ≤ v → v ≤ u → u = v := by omega
|
||||
|
||||
end Lean.Grind
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ builtin_initialize registerTraceClass `grind.offset.dist
|
|||
builtin_initialize registerTraceClass `grind.offset.internalize
|
||||
builtin_initialize registerTraceClass `grind.offset.internalize.term (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.offset.propagate
|
||||
builtin_initialize registerTraceClass `grind.offset.eq
|
||||
builtin_initialize registerTraceClass `grind.offset.eq.to (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.offset.eq.from (inherited := true)
|
||||
|
||||
/-! Trace options for `grind` developers -/
|
||||
builtin_initialize registerTraceClass `grind.debug
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import Lean.Meta.Tactic.Grind.Arith.Offset
|
|||
|
||||
namespace Lean.Meta.Grind.Arith
|
||||
|
||||
def internalize (e : Expr) : GoalM Unit := do
|
||||
Offset.internalizeCnstr e
|
||||
def internalize (e : Expr) (parent : Expr) : GoalM Unit := do
|
||||
Offset.internalize e parent
|
||||
|
||||
end Lean.Meta.Grind.Arith
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
prelude
|
||||
import Lean.Meta.Basic
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Util
|
||||
|
||||
namespace Lean.Meta.Grind.Arith.Offset
|
||||
/-- Construct a model that statisfies all offset constraints -/
|
||||
|
|
@ -33,7 +34,13 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Nat)) := do
|
|||
for u in [:nodes.size] do
|
||||
let some val := pre[u]! | unreachable!
|
||||
let val := (val - min).toNat
|
||||
r := r.push (nodes[u]!, val)
|
||||
let e := nodes[u]!
|
||||
/-
|
||||
We should not include the assignment for auxiliary offset terms since
|
||||
they do not provide any additional information.
|
||||
-/
|
||||
unless isNatOffset? e |>.isSome do
|
||||
r := r.push (e, val)
|
||||
return r
|
||||
|
||||
end Lean.Meta.Grind.Arith.Offset
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ def mkNode (expr : Expr) : GoalM NodeId := do
|
|||
targets := s.targets.push {}
|
||||
proofs := s.proofs.push {}
|
||||
}
|
||||
markAsOffsetTerm expr
|
||||
return nodeId
|
||||
|
||||
private def getExpr (u : NodeId) : GoalM Expr := do
|
||||
|
|
@ -59,6 +60,11 @@ private def getDist? (u v : NodeId) : GoalM (Option Int) := do
|
|||
private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do
|
||||
return (← get').proofs[u]!.find? v
|
||||
|
||||
private def getNodeId (e : Expr) : GoalM NodeId := do
|
||||
let some nodeId := (← get').nodeMap.find? { expr := e }
|
||||
| throwError "internal `grind` error, term has not been internalized by offset module{indentExpr e}"
|
||||
return nodeId
|
||||
|
||||
/--
|
||||
Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the
|
||||
shortest path between `u` and `v`.
|
||||
|
|
@ -160,10 +166,24 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → GoalM B
|
|||
return !(← f c e)
|
||||
modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
|
||||
|
||||
/-- Equality propagation. -/
|
||||
private def propagateEq (u v : NodeId) (k : Int) : GoalM Unit := do
|
||||
if k != 0 then return ()
|
||||
let some k' ← getDist? v u | return ()
|
||||
if k' != 0 then return ()
|
||||
let ue ← getExpr u
|
||||
let ve ← getExpr v
|
||||
if (← isEqv ue ve) then return ()
|
||||
let huv ← mkProofForPath u v
|
||||
let hvu ← mkProofForPath v u
|
||||
trace[grind.offset.eq.from] "{ue}, {ve}"
|
||||
pushEq ue ve <| mkApp4 (mkConst ``Grind.Nat.eq_of_le_of_le) ue ve huv hvu
|
||||
|
||||
/-- Performs constraint propagation. -/
|
||||
private def propagateAll (u v : NodeId) (k : Int) : GoalM Unit := do
|
||||
updateCnstrsOf u v fun c e => return !(← propagateTrue u v k c e)
|
||||
updateCnstrsOf v u fun c e => return !(← propagateFalse u v k c e)
|
||||
propagateEq u v k
|
||||
|
||||
/--
|
||||
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
|
||||
|
|
@ -203,8 +223,7 @@ where
|
|||
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
|
||||
updateIfShorter i j (k₁+k+k₂) v
|
||||
|
||||
def internalizeCnstr (e : Expr) : GoalM Unit := do
|
||||
let some c := isNatOffsetCnstr? e | return ()
|
||||
private def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := do
|
||||
let u ← mkNode c.u
|
||||
let v ← mkNode c.v
|
||||
let c := { c with u, v }
|
||||
|
|
@ -222,6 +241,29 @@ def internalizeCnstr (e : Expr) : GoalM Unit := do
|
|||
s.cnstrsOf.insert (u, v) cs
|
||||
}
|
||||
|
||||
def internalize (e : Expr) (parent : Expr) : GoalM Unit := do
|
||||
if let some c := isNatOffsetCnstr? e then
|
||||
internalizeCnstr e c
|
||||
else if let some (b, k) := isNatOffset? e then
|
||||
if (isNatOffsetCnstr? parent).isSome then return ()
|
||||
-- `e` is of the form `b + k`
|
||||
let u ← mkNode e
|
||||
let v ← mkNode b
|
||||
-- `u = v + k`. So, we add edges for `u ≤ v + k` and `v + k ≤ u`.
|
||||
let h := mkApp (mkConst ``Nat.le_refl) e
|
||||
addEdge u v k h
|
||||
addEdge v u (-k) h
|
||||
|
||||
@[export lean_process_new_offset_eq]
|
||||
def processNewOffsetEqImpl (a b : Expr) : GoalM Unit := do
|
||||
unless isSameExpr a b do
|
||||
trace[grind.offset.eq.to] "{a}, {b}"
|
||||
let u ← getNodeId a
|
||||
let v ← getNodeId b
|
||||
let h ← mkEqProof a b
|
||||
addEdge u v 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_1) a b h
|
||||
addEdge v u 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_2) a b h
|
||||
|
||||
def traceDists : GoalM Unit := do
|
||||
let s ← get'
|
||||
for u in [:s.targets.size], es in s.targets.toArray do
|
||||
|
|
@ -231,13 +273,12 @@ def traceDists : GoalM Unit := do
|
|||
def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
|
||||
let u := (← get').nodes[c.u]!
|
||||
let v := (← get').nodes[c.v]!
|
||||
let mk := if c.le then mkNatLE else mkNatEq
|
||||
if c.k == 0 then
|
||||
return mk u v
|
||||
return mkNatLE u v
|
||||
else if c.k < 0 then
|
||||
return mk (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
|
||||
return mkNatLE (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
|
||||
else
|
||||
return mk u (mkNatAdd v (Lean.toExpr c.k.toNat))
|
||||
return mkNatLE u (mkNatAdd v (Lean.toExpr c.k.toNat))
|
||||
|
||||
def checkInvariants : GoalM Unit := do
|
||||
let s ← get'
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ def mkOfNegEqFalse (nodes : PArray Expr) (c : Cnstr NodeId) (h : Expr) : Expr :=
|
|||
let v := nodes[c.v]!
|
||||
if c.k == 0 then
|
||||
mkApp3 (mkConst ``Nat.of_le_eq_false) u v h
|
||||
else if c.k == -1 && c.le then
|
||||
else if c.k == -1 then
|
||||
mkApp3 (mkConst ``Nat.of_lo_eq_false_1) u v h
|
||||
else if c.k < 0 then
|
||||
mkApp4 (mkConst ``Nat.of_lo_eq_false) u v (toExprN (-c.k)) h
|
||||
|
|
|
|||
|
|
@ -50,23 +50,19 @@ structure Offset.Cnstr (α : Type) where
|
|||
u : α
|
||||
v : α
|
||||
k : Int := 0
|
||||
le : Bool := true
|
||||
deriving Inhabited
|
||||
|
||||
def Offset.Cnstr.neg : Cnstr α → Cnstr α
|
||||
| { u, v, k, le } => { u := v, v := u, le, k := -k - 1 }
|
||||
| { u, v, k } => { u := v, v := u, k := -k - 1 }
|
||||
|
||||
example (c : Offset.Cnstr α) : c.neg.neg = c := by
|
||||
cases c; simp [Offset.Cnstr.neg]; omega
|
||||
|
||||
def Offset.toMessageData [inst : ToMessageData α] (c : Offset.Cnstr α) : MessageData :=
|
||||
match c.k, c.le with
|
||||
| .ofNat 0, true => m!"{c.u} ≤ {c.v}"
|
||||
| .ofNat 0, false => m!"{c.u} = {c.v}"
|
||||
| .ofNat k, true => m!"{c.u} ≤ {c.v} + {k}"
|
||||
| .ofNat k, false => m!"{c.u} = {c.v} + {k}"
|
||||
| .negSucc k, true => m!"{c.u} + {k + 1} ≤ {c.v}"
|
||||
| .negSucc k, false => m!"{c.u} + {k + 1} = {c.v}"
|
||||
match c.k with
|
||||
| .ofNat 0 => m!"{c.u} ≤ {c.v}"
|
||||
| .ofNat k => m!"{c.u} ≤ {c.v} + {k}"
|
||||
| .negSucc k => m!"{c.u} + {k + 1} ≤ {c.v}"
|
||||
|
||||
instance : ToMessageData (Offset.Cnstr Expr) where
|
||||
toMessageData c := Offset.toMessageData c
|
||||
|
|
@ -74,16 +70,15 @@ instance : ToMessageData (Offset.Cnstr Expr) where
|
|||
/-- Returns `some cnstr` if `e` is offset constraint. -/
|
||||
def isNatOffsetCnstr? (e : Expr) : Option (Offset.Cnstr Expr) :=
|
||||
match_expr e with
|
||||
| LE.le _ inst a b => if isInstLENat inst then go a b true else none
|
||||
| Eq α a b => if isNatType α then go a b false else none
|
||||
| LE.le _ inst a b => if isInstLENat inst then go a b else none
|
||||
| _ => none
|
||||
where
|
||||
go (u v : Expr) (le : Bool) :=
|
||||
go (u v : Expr) :=
|
||||
if let some (u, k) := isNatOffset? u then
|
||||
some { u, k := - k, v, le }
|
||||
some { u, k := - k, v }
|
||||
else if let some (v, k) := isNatOffset? v then
|
||||
some { u, v, k := k, le }
|
||||
some { u, v, k := k }
|
||||
else
|
||||
some { u, v, le }
|
||||
some { u, v }
|
||||
|
||||
end Lean.Meta.Grind.Arith
|
||||
|
|
|
|||
|
|
@ -86,6 +86,18 @@ private partial def updateMT (root : Expr) : GoalM Unit := do
|
|||
setENode parent { node with mt := gmt }
|
||||
updateMT parent
|
||||
|
||||
/--
|
||||
Helper function for combining `ENode.offset?` fields and propagating an equality
|
||||
to the offset constraint module.
|
||||
-/
|
||||
private def propagateOffsetEq (root : Expr) (roofOffset? otherOffset? : Option Expr) : GoalM Unit := do
|
||||
let some otherOffset := otherOffset? | return ()
|
||||
if let some rootOffset := roofOffset? then
|
||||
processNewOffsetEq rootOffset otherOffset
|
||||
else
|
||||
let n ← getENode root
|
||||
setENode root { n with offset? := otherOffset? }
|
||||
|
||||
private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
|
||||
let lhsNode ← getENode lhs
|
||||
let rhsNode ← getENode rhs
|
||||
|
|
@ -146,17 +158,18 @@ where
|
|||
next := rhsRoot.next
|
||||
}
|
||||
setENode rhsNode.root { rhsRoot with
|
||||
next := lhsRoot.next
|
||||
size := rhsRoot.size + lhsRoot.size
|
||||
next := lhsRoot.next
|
||||
size := rhsRoot.size + lhsRoot.size
|
||||
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
|
||||
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
|
||||
}
|
||||
copyParentsTo parents rhsNode.root
|
||||
unless (← isInconsistent) do
|
||||
updateMT rhsRoot.self
|
||||
propagateOffsetEq rhsNode.root rhsRoot.offset? lhsRoot.offset?
|
||||
unless (← isInconsistent) do
|
||||
for parent in parents do
|
||||
propagateUp parent
|
||||
unless (← isInconsistent) do
|
||||
updateMT rhsRoot.self
|
||||
|
||||
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
|
||||
traverseEqc lhs fun n =>
|
||||
|
|
|
|||
|
|
@ -98,6 +98,8 @@ private def pushCastHEqs (e : Expr) : GoalM Unit := do
|
|||
| f@Eq.recOn α a motive b h v => pushHEq e v (mkApp6 (mkConst ``Grind.eqRecOn_heq f.constLevels!) α a motive b h v)
|
||||
| _ => return ()
|
||||
|
||||
def noParent := mkBVar 0
|
||||
|
||||
mutual
|
||||
/-- Internalizes the nested ground terms in the given pattern. -/
|
||||
private partial def internalizePattern (pattern : Expr) (generation : Nat) : GoalM Expr := do
|
||||
|
|
@ -146,7 +148,7 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) :
|
|||
trace_goal[grind.ematch] "reinsert `{thm.origin.key}`"
|
||||
modify fun s => { s with thmMap := s.thmMap.insert thm }
|
||||
|
||||
partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
|
||||
partial def internalize (e : Expr) (generation : Nat) (parent : Expr := noParent) : GoalM Unit := do
|
||||
if (← alreadyInternalized e) then return ()
|
||||
trace_goal[grind.internalize] "{e}"
|
||||
match e with
|
||||
|
|
@ -157,10 +159,10 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
|
|||
| .forallE _ d b _ =>
|
||||
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
|
||||
if (← isProp d <&&> isProp e) then
|
||||
internalize d generation
|
||||
internalize d generation e
|
||||
registerParent e d
|
||||
unless b.hasLooseBVars do
|
||||
internalize b generation
|
||||
internalize b generation e
|
||||
registerParent e b
|
||||
propagateUp e
|
||||
| .lit .. | .const .. =>
|
||||
|
|
@ -182,22 +184,22 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
|
|||
-- We only internalize the proposition. We can skip the proof because of
|
||||
-- proof irrelevance
|
||||
let c := args[0]!
|
||||
internalize c generation
|
||||
internalize c generation e
|
||||
registerParent e c
|
||||
else
|
||||
if let .const fName _ := f then
|
||||
activateTheoremPatterns fName generation
|
||||
else
|
||||
internalize f generation
|
||||
internalize f generation e
|
||||
registerParent e f
|
||||
for h : i in [: args.size] do
|
||||
let arg := args[i]
|
||||
internalize arg generation
|
||||
internalize arg generation e
|
||||
registerParent e arg
|
||||
mkENode e generation
|
||||
addCongrTable e
|
||||
updateAppMap e
|
||||
Arith.internalize e
|
||||
Arith.internalize e parent
|
||||
propagateUp e
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -202,13 +202,19 @@ structure ENode where
|
|||
on heterogeneous equality.
|
||||
-/
|
||||
heqProofs : Bool := false
|
||||
/--
|
||||
Unique index used for pretty printing and debugging purposes.
|
||||
-/
|
||||
/-- Unique index used for pretty printing and debugging purposes. -/
|
||||
idx : Nat := 0
|
||||
/-- The generation in which this enode was created. -/
|
||||
generation : Nat := 0
|
||||
/-- Modification time -/
|
||||
mt : Nat := 0
|
||||
/--
|
||||
The `offset?` field is used to propagate equalities from the `grind` congruence closure module
|
||||
to the offset constraints module. When `grind` merges two equivalence classes, and both have
|
||||
an associated `offset?` set to `some e`, the equality is propagated. This field is
|
||||
assigned during the internalization of offset terms.
|
||||
-/
|
||||
offset? : Option Expr := none
|
||||
deriving Inhabited, Repr
|
||||
|
||||
def ENode.isCongrRoot (n : ENode) :=
|
||||
|
|
@ -643,6 +649,21 @@ def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
|
|||
let interpreted ← isInterpreted e
|
||||
mkENodeCore e interpreted ctor generation
|
||||
|
||||
@[extern "lean_process_new_offset_eq"] -- forward definition
|
||||
opaque processNewOffsetEq (a b : Expr) : GoalM Unit
|
||||
|
||||
/--
|
||||
Marks `e` as a term of interest to the offset constraint module.
|
||||
If the root of `e`s equivalence class has already a term of interest,
|
||||
a new equality is propagated to the offset module.
|
||||
-/
|
||||
def markAsOffsetTerm (e : Expr) : GoalM Unit := do
|
||||
let n ← getRootENode e
|
||||
if let some e' := n.offset? then
|
||||
processNewOffsetEq e e'
|
||||
else
|
||||
setENode n.self { n with offset? := some e }
|
||||
|
||||
/-- Returns `true` is `e` is the root of its congruence class. -/
|
||||
def isCongrRoot (e : Expr) : GoalM Bool := do
|
||||
return (← getENode e).isCongrRoot
|
||||
|
|
|
|||
|
|
@ -83,8 +83,7 @@ info: [grind.assert] foo (c + 1) = a
|
|||
-/
|
||||
#guard_msgs (info) in
|
||||
example : foo (c + 1) = a → c = b + 1 → a = g (foo b) := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
grind
|
||||
|
||||
set_option trace.grind.assert false
|
||||
|
||||
|
|
|
|||
|
|
@ -352,3 +352,24 @@ example (p r : Prop) (a b : Nat) : (c + 1 ≤ a ↔ p) → (c + 2 ≤ a + 1 ↔
|
|||
set_option trace.grind.split true in
|
||||
example (p r : Prop) (a b : Nat) : (c + 5 ≤ a ↔ p) → (c + 4 ≤ a ↔ r) → a ≤ b → b ≤ c + 3 → ¬p ∧ ¬r := by
|
||||
grind (splits := 0)
|
||||
|
||||
example (a b c d: Nat) : a ≤ b → b + 2 = c → c < d → a + 2 < d := by
|
||||
grind
|
||||
|
||||
example (a b c : Nat) : a + 2 = b → b + 3 = c → a + 5 ≤ c := by
|
||||
grind
|
||||
|
||||
example (a b c : Nat) : a + 2 = b → c ≤ a + 2 → a + 2 ≤ c → c = b := by
|
||||
grind
|
||||
|
||||
example (a b c : Nat) : a + 2 = b → b + 3 = c → a + 5 = c := by
|
||||
grind
|
||||
|
||||
example (f : Nat → Nat) (a b c d e : Nat) :
|
||||
f (a + 3) = b →
|
||||
f (c + 1) = d →
|
||||
c ≤ a + 2 →
|
||||
a + 1 ≤ e →
|
||||
e < c →
|
||||
b = d := by
|
||||
grind
|
||||
|
|
|
|||
|
|
@ -75,9 +75,8 @@ x✝ : ¬g (i + 1) j ⋯ = i + j + 1
|
|||
[prop] ¬g (i + 1) j ⋯ = i + j + 1[eqc] True propositions
|
||||
[prop] j + 1 ≤ i[eqc] False propositions
|
||||
[prop] g (i + 1) j ⋯ = i + j + 1[offset] Assignment satisfying offset contraints
|
||||
[assign] j := 0
|
||||
[assign] i := 1
|
||||
[assign] g (i + 1) j ⋯ := 0
|
||||
[assign] j := 1
|
||||
[assign] i := 2
|
||||
[assign] i + j := 0
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue