From a2ceebe2001c65045f043ae74075f6f4f84c8591 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 13 Dec 2025 15:32:34 +0100 Subject: [PATCH] feat: semiring `*` propagators in `grind` (#11653) This PR adds propagation rules corresponding to the `Semiring` normalization rules introduced in #11628. The new rules apply only to non-commutative semirings, since support for them in `grind` is limited. The normalization rules introduced unexpected behavior in Mathlib because they neutralize parameters such as `one_mul`: any theorem instance associated with such a parameter is reduced to `True` by the normalizer. --- src/Init/Grind/Lemmas.lean | 15 +++- .../Meta/Tactic/Grind/Arith/Propagate.lean | 66 +++++++++++++++- .../run/grind_semiring_norm_regression.lean | 77 +++++++++++++++++++ 3 files changed, 153 insertions(+), 5 deletions(-) create mode 100644 tests/lean/run/grind_semiring_norm_regression.lean diff --git a/src/Init/Grind/Lemmas.lean b/src/Init/Grind/Lemmas.lean index 1d609c2475..057988c09a 100644 --- a/src/Init/Grind/Lemmas.lean +++ b/src/Init/Grind/Lemmas.lean @@ -4,13 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.ByCases public import Init.Grind.Util - +public import Init.Grind.Ring.Basic public section - namespace Lean.Grind theorem rfl_true : true = true := @@ -193,4 +191,15 @@ theorem Nat.or_congr {a b : Nat} {k₁ k₂ k : Nat} (h₁ : a = k₁) (h₂ : b theorem Nat.shiftLeft_congr {a b : Nat} {k₁ k₂ k : Nat} (h₁ : a = k₁) (h₂ : b = k₂) : k == k₁ <<< k₂ → a <<< b = k := by simp_all theorem Nat.shiftRight_congr {a b : Nat} {k₁ k₂ k : Nat} (h₁ : a = k₁) (h₂ : b = k₂) : k == k₁ >>> k₂ → a >>> b = k := by simp_all +/-! Semiring propagators -/ + +theorem Semiring.one_mul_congr {α} [Semiring α] {a b : α} (h : a = 1) : a*b = b := by + simp [h, Semiring.one_mul] +theorem Semiring.zero_mul_congr {α} [Semiring α] {a b : α} (h : a = 0) : a*b = 0 := by + simp [h, Semiring.zero_mul] +theorem Semiring.mul_one_congr {α} [Semiring α] {a b : α} (h : b = 1) : a*b = a := by + simp [h, Semiring.mul_one] +theorem Semiring.mul_zero_congr {α} [Semiring α] {a b : α} (h : b = 0) : a*b = 0 := by + simp [h, Semiring.mul_zero] + end Lean.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Arith/Propagate.lean index dc27dcfc8a..00a7f62138 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Propagate.lean @@ -8,8 +8,11 @@ prelude public import Lean.Meta.Tactic.Grind.Types import Init.Grind import Lean.Meta.Tactic.Grind.PropagatorAttr +import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId +import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM +import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommSemiringM public section -namespace Lean.Meta.Grind +namespace Lean.Meta.Grind.Arith /-! This file defines propagators for `Nat` operators that have simprocs associated with them, but do not @@ -62,4 +65,63 @@ builtin_grind_propagator propagateNatShiftLeft ↑HShiftLeft.hShiftLeft := builtin_grind_propagator propagateNatShiftRight ↑HShiftRight.hShiftRight := propagateNatBinOp ``HShiftRight.hShiftRight ``Grind.Nat.shiftRight_congr (· >>> ·) -end Lean.Meta.Grind +private def supportedSemiring : Std.HashSet Name := + [``Nat, ``Int, ``Rat, ``BitVec, ``UInt8, ``UInt16, ``UInt32, ``Int64, ``Int8, ``Int16, ``Int32, ``Int64].foldl (init := {}) (·.insert ·) + +private def isSupportedSemiringQuick (type : Expr) : Bool := Id.run do + let .const declName _ := type.getAppFn | return false + return supportedSemiring.contains declName + +open CommRing in +/-- +Return `some inst` where `inst : Semiring type` if `type` is a semiring +that does not have good support in `grind ring`. That is, `grind ring` +supports only normalization, but not equational reasoning. +See comment at `propagateMul`. +-/ +private def isUnsupportedSemiring? (type : Expr) : GoalM (Option Expr) := do + if isSupportedSemiringQuick type then return none + if (← getCommRingId? type).isSome then return none + if (← getCommSemiringId? type).isSome then return none + if let some id ← getNonCommRingId? type then + let inst ← NonCommRingM.run id do return (← getRing).semiringInst + return some inst + if let some id ← getNonCommSemiringId? type then + let inst ← NonCommSemiringM.run id do return (← getSemiring).semiringInst + return some inst + return none + +private def isOfNat? (a : Expr) : MetaM (Option Nat) := do + let_expr OfNat.ofNat _ n _ := a | return none + getNatValue? n + +/-- +Propagator for the `0*a`, `1*a`, `a*0`, `a*1` for semirings that do not have good support in +`grind ring`. We need this propagator because have normalization rules for them, and users +were surprised when using `grind [zero_mul]` did not have any effect. In this scenario, +`grind` was correctly instantiating `zero_mul : 0*a = 0`, but the normalizer reduces the +instance to `True`. + +Alternative approach: We improve the support for equality reasoning for non-commutative rings +and semirings in `grind ring`. For example, we could just replace equalities and keep renormalizing. +If we implement this feature, this propagator can be deleted. +-/ +builtin_grind_propagator propagateMul ↑HMul.hMul := fun e => do + let_expr f@HMul.hMul α₁ α₂ α₃ _ a b := e | return () + let some semiringInst ← isUnsupportedSemiring? α₁ | return () + unless isSameExpr α₁ α₂ && isSameExpr α₁ α₃ do return () + let u :: _ := f.constLevels! | return () + let aRoot ← getRoot a + let bRoot ← getRoot b + if let some n ← isOfNat? aRoot then + if n == 0 then + pushEq e aRoot <| mkApp5 (mkConst ``Grind.Semiring.zero_mul_congr [u]) α₁ semiringInst a b (← mkEqProof a aRoot) + else if n == 1 then + pushEq e b <| mkApp5 (mkConst ``Grind.Semiring.one_mul_congr [u]) α₁ semiringInst a b (← mkEqProof a aRoot) + else if let some n ← isOfNat? bRoot then + if n == 0 then + pushEq e bRoot <| mkApp5 (mkConst ``Grind.Semiring.mul_zero_congr [u]) α₁ semiringInst a b (← mkEqProof b bRoot) + else if n == 1 then + pushEq e a <| mkApp5 (mkConst ``Grind.Semiring.mul_one_congr [u]) α₁ semiringInst a b (← mkEqProof b bRoot) + +end Lean.Meta.Grind.Arith diff --git a/tests/lean/run/grind_semiring_norm_regression.lean b/tests/lean/run/grind_semiring_norm_regression.lean new file mode 100644 index 0000000000..aa09a776fb --- /dev/null +++ b/tests/lean/run/grind_semiring_norm_regression.lean @@ -0,0 +1,77 @@ +section Mathlib.Data.Nat.Init + +namespace Nat + +class AtLeastTwo (n : Nat) : Prop where + prop : 2 ≤ n + +instance (n : Nat) [NeZero n] : (n + 1).AtLeastTwo := + ⟨add_le_add (one_le_iff_ne_zero.mpr (NeZero.ne n)) (Nat.le_refl 1)⟩ + +end Nat + +end Mathlib.Data.Nat.Init + +section Mathlib.Data.Nat.Cast.Defs + +instance {R : Type} {n : Nat} [NatCast R] [Nat.AtLeastTwo n] : + OfNat R n where + ofNat := n.cast + +end Mathlib.Data.Nat.Cast.Defs +section Mathlib.Algebra.GroupWithZero.Defs + +class MulZeroClass (α : Type) extends Mul α, Zero α where + mul_zero : ∀ a : α, a * 0 = 0 + +end Mathlib.Algebra.GroupWithZero.Defs + +section Mathlib.Algebra.Ring.Defs + +class Semiring (α : Type) extends + One α, NatCast α, Add α, Mul α, MulZeroClass α + +end Mathlib.Algebra.Ring.Defs + +section Mathlib.Algebra.Ring.GrindInstances + +instance Semiring.toGrindSemiring (α : Type) [s : Semiring α] : + Lean.Grind.Semiring α := + { s with + nsmul := sorry + npow := sorry + ofNat | 0 | 1 | n + 2 => inferInstance + natCast := sorry + add_zero := sorry + mul_one := sorry + zero_mul := sorry + pow_zero := sorry + pow_succ := sorry + ofNat_eq_natCast := sorry + ofNat_succ := sorry + nsmul_eq_natCast_mul := sorry + add_comm := sorry + left_distrib := sorry + right_distrib := sorry + mul_zero := sorry + add_assoc := sorry + mul_assoc := sorry + one_mul := sorry } + +end Mathlib.Algebra.Ring.GrindInstances + +section Mathlib.Algebra.Polynomial.Coeff + +theorem coeff_mul_X_pow {R : Type} [Semiring R] (p : R) (n d : Nat) : + ∀ b, b.1 + b.2 = d + n → b ≠ (d, n) → p * (if n = b.2 then 1 else 0) = 0 := by + grind only [MulZeroClass.mul_zero] + +theorem coeff_mul_X_pow' {R : Type} [Semiring R] (p : R) (n d : Nat) : + ∀ b, b.1 + b.2 = d + n → b ≠ (d, n) → p * (if n = b.2 then 1 else 0) = 0 := by + grind only + +example [Semiring α] (a b c : α) : b = 0 → a * b * c = 0 := by + grind only + +example [Semiring α] (a b c : α) : c = 1 → a = 1 → a * b * c = b := by + grind only