fix: propagator for a^(n+m) in grind (#10964)

This PR adds a propagator for `a^(n+m)` and removes its normalizer. This
change was motivated by issue #10661

Closes #10661
This commit is contained in:
Leonardo de Moura 2025-10-25 20:52:28 -07:00 committed by GitHub
parent 93c5bd0fdd
commit f8b0beeba9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 68 additions and 12 deletions

View file

@ -201,6 +201,9 @@ theorem pow_add (a : α) (k₁ k₂ : Nat) : a ^ (k₁ + k₂) = a^k₁ * a^k₂
next => simp [pow_zero, mul_one]
next k₂ ih => rw [Nat.add_succ, pow_succ, pow_succ, ih, mul_assoc]
theorem pow_add_congr (a r : α) (k k₁ k₂ : Nat) : k = k₁ + k₂ → a^k₁ * a^k₂ = r → a ^ k = r := by
intros; subst k r; rw [pow_add]
theorem natCast_pow (x : Nat) (k : Nat) : ((x ^ k : Nat) : α) = (x : α) ^ k := by
induction k
next => simp [pow_zero, Nat.pow_zero, natCast_one]

View file

@ -26,6 +26,7 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadSemiring
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Action
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Power
public section
namespace Lean.Meta.Grind.Arith.CommRing
builtin_initialize registerTraceClass `grind.ring

View file

@ -0,0 +1,37 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.Types
import Init.Grind
import Lean.Meta.Tactic.Grind.PropagatorAttr
import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.Arith.Simproc
import Lean.Meta.NatInstTesters
public section
namespace Lean.Meta.Grind.Arith.CommRing
builtin_grind_propagator propagatePower ↑HPow.hPow := fun e => do
-- **Note**: We are not checking whether the `^` instance is the expected ones.
let_expr HPow.hPow α n α' _ a b := e | return ()
let_expr Nat := n | return ()
unless isSameExpr α α' do return ()
traverseEqc b fun bENode => do
let b' := bENode.self
match_expr b' with
| HAdd.hAdd n₁ n₂ n₃ inst b₁ b₂ =>
unless isSameExpr n n₁ && isSameExpr n n₂ && isSameExpr n n₃ do return ()
unless (← isInstHAddNat inst) do return ()
let pwFn := e.appFn!.appFn!
let r ← mkMul (mkApp2 pwFn a b₁) (mkApp2 pwFn a b₂)
let r ← preprocess r
internalize r.expr (← getGeneration e)
let some h ← mkSemiringThm ``Grind.Semiring.pow_add_congr α | return ()
let h := mkApp7 h a r.expr b b₁ b₂ (← mkEqProof b b') (← r.getProof)
pushEq e r.expr h
| _ => return ()
end Lean.Meta.Grind.Arith.CommRing

View file

@ -13,21 +13,30 @@ import Init.Grind.Ring.Field
public section
namespace Lean.Meta.Grind.Arith
private def mkSemiringThm (declName : Name) (α : Expr) : MetaM (Option Expr) := do
def mkSemiringThm (declName : Name) (α : Expr) : MetaM (Option Expr) := do
let some u ← getDecLevel? α | return none
let semiring := mkApp (mkConst ``Grind.Semiring [u]) α
let some semiringInst ← synthInstanceMeta? semiring | return none
return mkApp2 (mkConst declName [u]) α semiringInst
/--
Applies `a^(m+n) = a^m * a^n`, `a^0 = 1`, `a^1 = a`.
Applies `a^0 = 1`, `a^1 = a`.
We do normalize `a^0` and `a^1` when converting expressions into polynomials,
but we need to normalize them here when for other preprocessing steps such as
`a / b = a*b⁻¹`. If `b` is of the form `c^1`, it will be treated as an
atom in the comm ring module.
atom in the ring module.
**Note**: We used to expand `a^(n+m)` here, but it prevented `grind` from solving
simple problems such as
```
example {k : Nat} (h : k - 1 + 1 = k) :
2 ^ (k - 1 + 1) = 2 ^ k := by
grind
```
We now use a propagator for `a^(n+m)` which adds the `a^n*a^m` to the equivalence class.
-/
builtin_simproc_decl expandPowAdd (_ ^ _) := fun e => do
builtin_simproc_decl expandPow01 (_ ^ _) := fun e => do
let_expr HPow.hPow α nat α' _ a k := e | return .continue
let_expr Nat ← nat | return .continue
if let some k ← getNatValue? k then
@ -42,13 +51,7 @@ builtin_simproc_decl expandPowAdd (_ ^ _) := fun e => do
return .done { expr := a, proof? := some (mkApp h a) }
else
return .continue
else
let_expr HAdd.hAdd _ _ _ _ m n := k | return .continue
unless (← isDefEq α α') do return .continue
let some h ← mkSemiringThm ``Grind.Semiring.pow_add α | return .continue
let pwFn := e.appFn!.appFn!
let r ← mkMul (mkApp2 pwFn a m) (mkApp2 pwFn a n)
return .visit { expr := r, proof? := some (mkApp3 h a m n) }
return .continue
private def notField : Std.HashSet Name :=
[``Nat, ``Int, ``BitVec, ``UInt8, ``UInt16, ``UInt32, ``Int64, ``Int8, ``Int16, ``Int32, ``Int64].foldl (init := {}) (·.insert ·)
@ -185,7 +188,7 @@ Add additional arithmetic simprocs
-/
def addSimproc (s : Simprocs) : CoreM Simprocs := do
let s ← s.add ``expandPowAdd (post := true)
let s ← s.add ``expandPow01 (post := true)
let s ← s.add ``expandDiv (post := true)
let s ← s.add ``normNatAddInst (post := false)
let s ← s.add ``normNatMulInst (post := false)

View file

@ -0,0 +1,12 @@
example {k : Nat} (h : k - 1 + 1 = k) :
2 ^ (k - 1 + 1) = 2 ^ k := by
grind
example (h : a = b + c) : 2 ^ a = 2^b * 2^c := by
grind
example (h : a = c + b) : 2 ^ a = 2^b * 2^c := by
grind
example (h : a = 1 + b) : 2 ^ a = 2^b * 2 := by
grind