feat: pow support in grind cutsat (#10071)

This PR improves support for `a^n` in `grind cutsat`. For example, if
`cutsat` discovers that `a` and `b` are equal to numerals, it now
propagates the equality. This PR is similar to #9996, but `a^b`.
Example:

```lean
example (n : Nat) : n = 2 → 2 ^ (n+1) = 8 := by
  grind
```

With #10022, it also improves the support for `BitVec n` when `n` is not
numeral. Example:

```lean
example {n m : Nat} (x : BitVec n)
    : 2 ≤ n → n ≤ m → m = 2 → x = 0 ∨ x = 1 ∨ x = 2 ∨ x = 3 := by
  grind
```
This commit is contained in:
Leonardo de Moura 2025-08-22 18:55:05 -07:00 committed by GitHub
parent 1f9bba9d39
commit a63d483258
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 62 additions and 3 deletions

View file

@ -2198,6 +2198,9 @@ theorem mod_eq (a b k : Int) (h : b = k) : a % b = a % k := by simp [*]
theorem div_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a/b') : a / b = k := by simp_all
theorem mod_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a%b') : a % b = k := by simp_all
theorem pow_eq (a : Int) (b : Nat) (a' b' k : Int) (h₁ : a = a') (h₂ : ↑b = b') (h₃ : k == a'^b'.toNat) : a^b = k := by
simp [← h₁, ← h₂] at h₃; simp [h₃]
end Int.Linear
theorem Int.not_le_eq (a b : Int) : (¬a ≤ b) = (b + 1 ≤ a) := by

View file

@ -249,6 +249,29 @@ private def propagateNonlinearMod (x : Var) : GoalM Bool := do
c'.assert
return true
private def propagateNonlinearPow (x : Var) : GoalM Bool := do
let e ← getVar x
let_expr HPow.hPow _ _ _ i a b := e | return false
unless (← isInstHPowInt i) do return false
let (ka, ca?) ← if let some ka ← getIntValue? a then
pure (ka, none)
else if let some (ka, ca) ← isExprEqConst? a then
pure (ka, some ca)
else
return false
let (kb, cb?) ← if let some kb ← getNatValue? b then
pure (kb, none)
else
let (b', _) ← mkNatVar b
if let some (kb, cb) ← isExprEqConst? b' then
pure (kb.toNat, some cb)
else
return false
trace[Meta.debug] ">> e: {e}, k: {ka^kb}"
let c' ← pure { p := .add 1 x (.num (-(ka^kb))), h := .pow ka ca? kb cb? : EqCnstr }
c'.assert
return true
@[export lean_cutsat_propagate_nonlinear]
def propagateNonlinearTermImpl (y : Var) (x : Var) : GoalM Bool := do
unless (← isVarEqConst? y).isSome do return false
@ -256,6 +279,7 @@ def propagateNonlinearTermImpl (y : Var) (x : Var) : GoalM Bool := do
| HMul.hMul _ _ _ _ _ _ => propagateNonlinearMul x
| HDiv.hDiv _ _ _ _ _ _ => propagateNonlinearDiv x
| HMod.hMod _ _ _ _ _ _ => propagateNonlinearMod x
| HPow.hPow _ _ _ _ _ _ => propagateNonlinearPow x
| _ => return false
def propagateNonlinearTerms (y : Var) : GoalM Unit := do

View file

@ -16,10 +16,11 @@ public section
namespace Lean.Meta.Grind.Arith.Cutsat
/-- Given `e`, returns `(NatCast.natCast e, rfl)` -/
def mkNatVar (e : Expr) : GoalM (Expr × Expr) := do
if let some p := (← get').natToIntMap.find? { expr := e } then
return p
let e' := mkIntNatCast e
let e' ← shareCommon (mkIntNatCast e)
let he := mkApp (mkApp (mkConst ``Eq.refl [1]) Int.mkType) e'
let r := (e', he)
modify' fun s => { s with

View file

@ -309,6 +309,22 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
let k := aVal % b'
let h := mkApp6 (mkConst ``Int.Linear.mod_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h
| .pow ka ca? kb cb? =>
let .add _ x _ := c'.p | c'.throwUnexpected
let_expr HPow.hPow _ _ _ _ a b := (← getCurrVars)[x]! | c'.throwUnexpected
let h₁ ← if let some ca := ca? then
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl (← getVarOf a)) (toExpr ka) (← mkPolyDecl ca.p) eagerReflBoolTrue (← ca.toExprProof)
else
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit ka)
let kbInt := Int.ofNat kb
let h₂ ← if let some cb := cb? then
let (b', _) ← mkNatVar b
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl (← getVarOf b')) (toExpr kbInt) (← mkPolyDecl cb.p) eagerReflBoolTrue (← cb.toExprProof)
else
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit kb)
let k := ka^kb
let h := mkApp8 (mkConst ``Int.Linear.pow_eq) a b (toExpr ka) (toExpr kbInt) (toExpr k) h₁ h₂ eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h
partial def mkMulEqProof (x : Var) (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) (c' : EqCnstr) : ProofM Expr := do
let h ← go (← getCurrVars)[x]!
@ -611,6 +627,7 @@ partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do u
| .commRingNorm c .. | .reorder c | .norm c | .divCoeffs c | .div _ _ c | .mod _ _ c => c.collectDecVars
| .subst _ c₁ c₂ | .ofLeGe c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
| .mul _ cs => cs.forM fun (_, _, c) => c.collectDecVars
| .pow _ ca? _ cb? => ca?.forM (·.collectDecVars); cb?.forM (·.collectDecVars)
partial def CooperSplit.collectDecVars (s : CooperSplit) : CollectDecVarsM Unit := do unless (← alreadyVisited s) do
s.pred.c₁.collectDecVars

View file

@ -115,6 +115,7 @@ inductive EqCnstrProof where
- If `?y = none`, then it is a proof for `a % b = a%k` where `c` is a proof that `b = k`. `a` is a numeral in this case.
-/
mod (k : Int) (y? : Option Var) (c : EqCnstr)
| pow (ka : Int) (ca? : Option EqCnstr) (kb : Nat) (cb? : Option EqCnstr)
/-- A divisibility constraint and its justification/proof. -/
structure DvdCnstr where

View file

@ -24,6 +24,9 @@ private def isNonlinearTerm (e : Expr) : MetaM Bool := do
| HMul.hMul _ _ _ i _ _ => isInstHMulInt i
| HDiv.hDiv _ _ _ i _ b => pure (← getIntValue? b).isNone <&&> isInstHDivInt i
| HMod.hMod _ _ _ i _ b => pure (← getIntValue? b).isNone <&&> isInstHModInt i
| HPow.hPow _ _ _ i a b =>
unless (← isInstHPowInt i) do return false
return (← getIntValue? a).isNone || (← getIntValue? b).isNone
| _ => return false
private def registerNonlinearOcc (arg : Expr) (x : Var) : GoalM Unit := do
@ -31,7 +34,6 @@ private def registerNonlinearOcc (arg : Expr) (x : Var) : GoalM Unit := do
if (← get').elimEqs[y]!.isSome then
if (← propagateNonlinearTerm y x) then
return ()
let y ← mkVar arg
let occs := (← get').nonlinearOccs.find? y |>.getD []
unless x ∈ occs do
modify' fun s => { s with nonlinearOccs := s.nonlinearOccs.insert y (x::occs) }
@ -41,6 +43,14 @@ private partial def registerNonlinearOccsAt (e : Expr) (x : Var) : GoalM Unit :=
| HMul.hMul _ _ _ _ a b => go a; go b
| HDiv.hDiv _ _ _ _ _ b => registerNonlinearOcc b x
| HMod.hMod _ _ _ _ _ b => registerNonlinearOcc b x
| HPow.hPow _ _ _ _ a b =>
if (← getIntValue? a).isNone then
registerNonlinearOcc a x
if (← getIntValue? b).isNone then
-- Recall that `b : Nat`, we must create `NatCast.natCast b` and watch it.
let (b', _) ← mkNatVar b
internalize b' (← getGeneration b)
registerNonlinearOcc b' x
| _ => return ()
where
go (e : Expr) : GoalM Unit := do

View file

@ -114,3 +114,6 @@ example {n : Nat} (j : Fin (n + 1)) : j ≤ j := by
example {n : Nat} (x y : Fin ((n + 1) + 1)) (h₂ : ¬x = y) (h : ¬x < y) : y < x := by
grind
example {n m : Nat} (x : BitVec n) : 2 ≤ n → n ≤ m → m = 2 → x = 0 x = 1 x = 2 x = 3 := by
grind

View file

@ -66,7 +66,7 @@ example [CommRing α] [IsCharP α 8] (x : α) : (x + 1)*(x - 1) = x^2 → False
#guard_msgs (trace) in
set_option trace.grind.ring.assert.queue true in
example (x y : Int) : x + 16*y^2 - 7*x^2 = 0 → False := by
fail_if_success grind
fail_if_success grind -cutsat
sorry
/--