From f8b0beeba9475c75c9780ef757c08f5227f87836 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 25 Oct 2025 20:52:28 -0700 Subject: [PATCH] fix: propagator for `a^(n+m)` in `grind` (#10964) This PR adds a propagator for `a^(n+m)` and removes its normalizer. This change was motivated by issue #10661 Closes #10661 --- src/Init/Grind/Ring/Basic.lean | 3 ++ .../Meta/Tactic/Grind/Arith/CommRing.lean | 1 + .../Tactic/Grind/Arith/CommRing/Power.lean | 37 +++++++++++++++++++ src/Lean/Meta/Tactic/Grind/Arith/Simproc.lean | 27 ++++++++------ tests/lean/run/grind_10661.lean | 12 ++++++ 5 files changed, 68 insertions(+), 12 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Power.lean create mode 100644 tests/lean/run/grind_10661.lean diff --git a/src/Init/Grind/Ring/Basic.lean b/src/Init/Grind/Ring/Basic.lean index b23dc12e2c..ad71182c30 100644 --- a/src/Init/Grind/Ring/Basic.lean +++ b/src/Init/Grind/Ring/Basic.lean @@ -201,6 +201,9 @@ theorem pow_add (a : α) (k₁ k₂ : Nat) : a ^ (k₁ + k₂) = a^k₁ * a^k₂ next => simp [pow_zero, mul_one] next k₂ ih => rw [Nat.add_succ, pow_succ, pow_succ, ih, mul_assoc] +theorem pow_add_congr (a r : α) (k k₁ k₂ : Nat) : k = k₁ + k₂ → a^k₁ * a^k₂ = r → a ^ k = r := by + intros; subst k r; rw [pow_add] + theorem natCast_pow (x : Nat) (k : Nat) : ((x ^ k : Nat) : α) = (x : α) ^ k := by induction k next => simp [pow_zero, Nat.pow_zero, natCast_one] diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean index 546597520a..4aa6bb019b 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean @@ -26,6 +26,7 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadSemiring public import Lean.Meta.Tactic.Grind.Arith.CommRing.Action +public import Lean.Meta.Tactic.Grind.Arith.CommRing.Power public section namespace Lean.Meta.Grind.Arith.CommRing builtin_initialize registerTraceClass `grind.ring diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Power.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Power.lean new file mode 100644 index 0000000000..02df2a1916 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Power.lean @@ -0,0 +1,37 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.Types +import Init.Grind +import Lean.Meta.Tactic.Grind.PropagatorAttr +import Lean.Meta.Tactic.Grind.Simp +import Lean.Meta.Tactic.Grind.Arith.Simproc +import Lean.Meta.NatInstTesters +public section +namespace Lean.Meta.Grind.Arith.CommRing + +builtin_grind_propagator propagatePower ↑HPow.hPow := fun e => do + -- **Note**: We are not checking whether the `^` instance is the expected ones. + let_expr HPow.hPow α n α' _ a b := e | return () + let_expr Nat := n | return () + unless isSameExpr α α' do return () + traverseEqc b fun bENode => do + let b' := bENode.self + match_expr b' with + | HAdd.hAdd n₁ n₂ n₃ inst b₁ b₂ => + unless isSameExpr n n₁ && isSameExpr n n₂ && isSameExpr n n₃ do return () + unless (← isInstHAddNat inst) do return () + let pwFn := e.appFn!.appFn! + let r ← mkMul (mkApp2 pwFn a b₁) (mkApp2 pwFn a b₂) + let r ← preprocess r + internalize r.expr (← getGeneration e) + let some h ← mkSemiringThm ``Grind.Semiring.pow_add_congr α | return () + let h := mkApp7 h a r.expr b b₁ b₂ (← mkEqProof b b') (← r.getProof) + pushEq e r.expr h + | _ => return () + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Simproc.lean b/src/Lean/Meta/Tactic/Grind/Arith/Simproc.lean index 01ccb24212..3869cbd10a 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Simproc.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Simproc.lean @@ -13,21 +13,30 @@ import Init.Grind.Ring.Field public section namespace Lean.Meta.Grind.Arith -private def mkSemiringThm (declName : Name) (α : Expr) : MetaM (Option Expr) := do +def mkSemiringThm (declName : Name) (α : Expr) : MetaM (Option Expr) := do let some u ← getDecLevel? α | return none let semiring := mkApp (mkConst ``Grind.Semiring [u]) α let some semiringInst ← synthInstanceMeta? semiring | return none return mkApp2 (mkConst declName [u]) α semiringInst /-- -Applies `a^(m+n) = a^m * a^n`, `a^0 = 1`, `a^1 = a`. +Applies `a^0 = 1`, `a^1 = a`. We do normalize `a^0` and `a^1` when converting expressions into polynomials, but we need to normalize them here when for other preprocessing steps such as `a / b = a*b⁻¹`. If `b` is of the form `c^1`, it will be treated as an -atom in the comm ring module. +atom in the ring module. + +**Note**: We used to expand `a^(n+m)` here, but it prevented `grind` from solving +simple problems such as +``` +example {k : Nat} (h : k - 1 + 1 = k) : + 2 ^ (k - 1 + 1) = 2 ^ k := by + grind +``` +We now use a propagator for `a^(n+m)` which adds the `a^n*a^m` to the equivalence class. -/ -builtin_simproc_decl expandPowAdd (_ ^ _) := fun e => do +builtin_simproc_decl expandPow01 (_ ^ _) := fun e => do let_expr HPow.hPow α nat α' _ a k := e | return .continue let_expr Nat ← nat | return .continue if let some k ← getNatValue? k then @@ -42,13 +51,7 @@ builtin_simproc_decl expandPowAdd (_ ^ _) := fun e => do return .done { expr := a, proof? := some (mkApp h a) } else return .continue - else - let_expr HAdd.hAdd _ _ _ _ m n := k | return .continue - unless (← isDefEq α α') do return .continue - let some h ← mkSemiringThm ``Grind.Semiring.pow_add α | return .continue - let pwFn := e.appFn!.appFn! - let r ← mkMul (mkApp2 pwFn a m) (mkApp2 pwFn a n) - return .visit { expr := r, proof? := some (mkApp3 h a m n) } + return .continue private def notField : Std.HashSet Name := [``Nat, ``Int, ``BitVec, ``UInt8, ``UInt16, ``UInt32, ``Int64, ``Int8, ``Int16, ``Int32, ``Int64].foldl (init := {}) (·.insert ·) @@ -185,7 +188,7 @@ Add additional arithmetic simprocs -/ def addSimproc (s : Simprocs) : CoreM Simprocs := do - let s ← s.add ``expandPowAdd (post := true) + let s ← s.add ``expandPow01 (post := true) let s ← s.add ``expandDiv (post := true) let s ← s.add ``normNatAddInst (post := false) let s ← s.add ``normNatMulInst (post := false) diff --git a/tests/lean/run/grind_10661.lean b/tests/lean/run/grind_10661.lean new file mode 100644 index 0000000000..87c9e92d26 --- /dev/null +++ b/tests/lean/run/grind_10661.lean @@ -0,0 +1,12 @@ +example {k : Nat} (h : k - 1 + 1 = k) : + 2 ^ (k - 1 + 1) = 2 ^ k := by + grind + +example (h : a = b + c) : 2 ^ a = 2^b * 2^c := by + grind + +example (h : a = c + b) : 2 ^ a = 2^b * 2^c := by + grind + +example (h : a = 1 + b) : 2 ^ a = 2^b * 2 := by + grind