From e5a6901161541aaa23fbb40ac014909685acc4e9 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 1 Nov 2025 08:37:17 -0700 Subject: [PATCH] feat: `Nat` equality propagation in `grind order` (#11049) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements equality propagation for `Nat` in `grind order`. `grind order` supports offset equalities for rings, but it has an adapter for `Nat`. Example: ```lean example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f (1 + a) = f (1 + b + 1) := by grind -offset -mbtc -lia -linarith (splits := 0) ``` --- src/Init/Grind/Order.lean | 4 ++ src/Lean/Meta/Tactic/Grind/Order/Assert.lean | 54 +++++++++++++++++++- tests/lean/run/grind_10885.lean | 10 +--- tests/lean/run/grind_order_eq.lean | 9 ++++ 4 files changed, 66 insertions(+), 11 deletions(-) diff --git a/src/Init/Grind/Order.lean b/src/Init/Grind/Order.lean index 5667b2126e..d7ed1a0f90 100644 --- a/src/Init/Grind/Order.lean +++ b/src/Init/Grind/Order.lean @@ -65,6 +65,10 @@ theorem le_of_offset_eq_2_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsP rw [Ring.intCast_neg, Semiring.add_assoc, Semiring.add_comm (α := α) k, Ring.neg_add_cancel, Semiring.add_zero] apply Std.IsPreorder.le_refl +theorem nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCast b = y → x = y → a = b := by + intro _ _; subst x y; intro h + exact Int.natCast_inj.mp h + theorem le_of_not_le {α} [LE α] [Std.IsLinearPreorder α] {a b : α} : ¬ a ≤ b → b ≤ a := by intro h diff --git a/src/Lean/Meta/Tactic/Grind/Order/Assert.lean b/src/Lean/Meta/Tactic/Grind/Order/Assert.lean index 0b750d2676..8291fc448c 100644 --- a/src/Lean/Meta/Tactic/Grind/Order/Assert.lean +++ b/src/Lean/Meta/Tactic/Grind/Order/Assert.lean @@ -129,11 +129,42 @@ def propagatePending : OrderM Unit := do | .eq u v => let ue ← getExpr u let ve ← getExpr v - unless (← isEqv ue ve) do + if (← alreadyInternalized ue <&&> alreadyInternalized ve) then + unless (← isEqv ue ve) do + let huv ← mkProofForPath u v + let hvu ← mkProofForPath v u + let h ← mkEqProofOfLeOfLe ue ve huv hvu + pushEq ue ve h + -- Checks whether `ue` and `ve` are auxiliary terms + let some (ue', h₁) ← getOriginal? ue | continue + let some (ve', h₂) ← getOriginal? ve | continue + if (← alreadyInternalized ue' <&&> alreadyInternalized ve') then + unless (← isEqv ue' ve') do let huv ← mkProofForPath u v let hvu ← mkProofForPath v u let h ← mkEqProofOfLeOfLe ue ve huv hvu - pushEq ue ve h + /- + We have + - `h₁ : ↑ue' = ue` + - `h₂ : ↑ve' = ve` + - `h : ue = ve` + -/ + pushEq ue' ve' <| mkApp7 (mkConst ``Grind.Order.nat_eq) ue' ve' ue ve h₁ h₂ h +where + /-- + If `e` is an auxiliary term used to represent some term `a`, returns + `some (a, h)` s.t. `h : ↑a = e` + **Note**: We currently only support `Nat`. Thus `↑a` is actually + `NatCast.natCast a`. If we decide to support arbitrary semirings + in this module, we must adjust this code. + -/ + getOriginal? (e : Expr) : GoalM (Option (Expr × Expr)) := do + if let some r := (← get').termMapInv.find? { expr := e } then + return some r + else + let_expr NatCast.natCast _ _ a := e | return none + let h ← mkEqRefl e + return some (a, h) /-- Returns `true` if `e` is already `True` in the `grind` core. @@ -190,6 +221,7 @@ Traverses the constraints `c` (representing an expression `e`) s.t. /-- Equality propagation. -/ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do + if u == v then return () if (← isPartialOrder) then if !k.isZero then return () let some k' ← getDist? v u | return () @@ -199,6 +231,24 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do if (← alreadyInternalized ue <&&> alreadyInternalized ve) then if (← isEqv ue ve) then return () pushToPropagate <| .eq u v + else + /- + Check whether `ue` and `ve` are auxiliary terms used to encode `Nat` terms. + **Note**: `getOriginal?` is currently hard coded to the `Nat` case since + it is the only type we map to rings. If in the future, we want to support + arbitrary `Semiring`s, we must adjust this code. + -/ + let some ue ← getOriginal? ue | return () + let some ve ← getOriginal? ve | return () + if (← alreadyInternalized ue <&&> alreadyInternalized ve) then + if (← isEqv ue ve) then return () + pushToPropagate <| .eq u v +where + getOriginal? (e : Expr) : GoalM (Option Expr) := do + let_expr NatCast.natCast _ _ a := e + | let some (a, _) := (← get').termMapInv.find? { expr := e } | return none + return some a + return some a /-- Finds constrains and equalities to be propagated. -/ def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do diff --git a/tests/lean/run/grind_10885.lean b/tests/lean/run/grind_10885.lean index de268ae179..f0b0544d6f 100644 --- a/tests/lean/run/grind_10885.lean +++ b/tests/lean/run/grind_10885.lean @@ -1,12 +1,4 @@ example {a b : Nat} (ha : 1 ≤ a) : (a - 1 + 1) * b = a * b := by grind -/-- -info: Try these: - [apply] ⏎ - mbtc - cases #9501 - [apply] finish only [#9501] --/ -#guard_msgs in example {a b : Nat} (ha : 1 ≤ a) : (a - 1 + 1) * b = a * b := by - grind => finish? -- mbtc was applied consider nonlinear `*` + grind => done diff --git a/tests/lean/run/grind_order_eq.lean b/tests/lean/run/grind_order_eq.lean index 96904e6467..1321f792e2 100644 --- a/tests/lean/run/grind_order_eq.lean +++ b/tests/lean/run/grind_order_eq.lean @@ -9,3 +9,12 @@ example [CommRing α] [LE α] [LT α] [LawfulOrderLT α] [IsPartialOrder α] [Or example (a b : Int) (f : Int → Int) : a ≤ b + 1 → b ≤ a - 1 → f a = f (2 + b - 1) := by grind -mbtc -lia -linarith (splits := 0) + +example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f a = f (1 + b + 0) := by + grind -offset -mbtc -lia -linarith (splits := 0) + +example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ c → c ≤ a → f a = f c := by + grind -offset -mbtc -lia -linarith (splits := 0) + +example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f (1 + a) = f (1 + b + 1) := by + grind -offset -mbtc -lia -linarith (splits := 0)