feat: pow support in grind cutsat (#10071)
This PR improves support for `a^n` in `grind cutsat`. For example, if `cutsat` discovers that `a` and `b` are equal to numerals, it now propagates the equality. This PR is similar to #9996, but `a^b`. Example: ```lean example (n : Nat) : n = 2 → 2 ^ (n+1) = 8 := by grind ``` With #10022, it also improves the support for `BitVec n` when `n` is not numeral. Example: ```lean example {n m : Nat} (x : BitVec n) : 2 ≤ n → n ≤ m → m = 2 → x = 0 ∨ x = 1 ∨ x = 2 ∨ x = 3 := by grind ```
This commit is contained in:
parent
1f9bba9d39
commit
a63d483258
8 changed files with 62 additions and 3 deletions
|
|
@ -2198,6 +2198,9 @@ theorem mod_eq (a b k : Int) (h : b = k) : a % b = a % k := by simp [*]
|
|||
theorem div_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a/b') : a / b = k := by simp_all
|
||||
theorem mod_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a%b') : a % b = k := by simp_all
|
||||
|
||||
theorem pow_eq (a : Int) (b : Nat) (a' b' k : Int) (h₁ : a = a') (h₂ : ↑b = b') (h₃ : k == a'^b'.toNat) : a^b = k := by
|
||||
simp [← h₁, ← h₂] at h₃; simp [h₃]
|
||||
|
||||
end Int.Linear
|
||||
|
||||
theorem Int.not_le_eq (a b : Int) : (¬a ≤ b) = (b + 1 ≤ a) := by
|
||||
|
|
|
|||
|
|
@ -249,6 +249,29 @@ private def propagateNonlinearMod (x : Var) : GoalM Bool := do
|
|||
c'.assert
|
||||
return true
|
||||
|
||||
private def propagateNonlinearPow (x : Var) : GoalM Bool := do
|
||||
let e ← getVar x
|
||||
let_expr HPow.hPow _ _ _ i a b := e | return false
|
||||
unless (← isInstHPowInt i) do return false
|
||||
let (ka, ca?) ← if let some ka ← getIntValue? a then
|
||||
pure (ka, none)
|
||||
else if let some (ka, ca) ← isExprEqConst? a then
|
||||
pure (ka, some ca)
|
||||
else
|
||||
return false
|
||||
let (kb, cb?) ← if let some kb ← getNatValue? b then
|
||||
pure (kb, none)
|
||||
else
|
||||
let (b', _) ← mkNatVar b
|
||||
if let some (kb, cb) ← isExprEqConst? b' then
|
||||
pure (kb.toNat, some cb)
|
||||
else
|
||||
return false
|
||||
trace[Meta.debug] ">> e: {e}, k: {ka^kb}"
|
||||
let c' ← pure { p := .add 1 x (.num (-(ka^kb))), h := .pow ka ca? kb cb? : EqCnstr }
|
||||
c'.assert
|
||||
return true
|
||||
|
||||
@[export lean_cutsat_propagate_nonlinear]
|
||||
def propagateNonlinearTermImpl (y : Var) (x : Var) : GoalM Bool := do
|
||||
unless (← isVarEqConst? y).isSome do return false
|
||||
|
|
@ -256,6 +279,7 @@ def propagateNonlinearTermImpl (y : Var) (x : Var) : GoalM Bool := do
|
|||
| HMul.hMul _ _ _ _ _ _ => propagateNonlinearMul x
|
||||
| HDiv.hDiv _ _ _ _ _ _ => propagateNonlinearDiv x
|
||||
| HMod.hMod _ _ _ _ _ _ => propagateNonlinearMod x
|
||||
| HPow.hPow _ _ _ _ _ _ => propagateNonlinearPow x
|
||||
| _ => return false
|
||||
|
||||
def propagateNonlinearTerms (y : Var) : GoalM Unit := do
|
||||
|
|
|
|||
|
|
@ -16,10 +16,11 @@ public section
|
|||
|
||||
namespace Lean.Meta.Grind.Arith.Cutsat
|
||||
|
||||
/-- Given `e`, returns `(NatCast.natCast e, rfl)` -/
|
||||
def mkNatVar (e : Expr) : GoalM (Expr × Expr) := do
|
||||
if let some p := (← get').natToIntMap.find? { expr := e } then
|
||||
return p
|
||||
let e' := mkIntNatCast e
|
||||
let e' ← shareCommon (mkIntNatCast e)
|
||||
let he := mkApp (mkApp (mkConst ``Eq.refl [1]) Int.mkType) e'
|
||||
let r := (e', he)
|
||||
modify' fun s => { s with
|
||||
|
|
|
|||
|
|
@ -309,6 +309,22 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
|
|||
let k := aVal % b'
|
||||
let h := mkApp6 (mkConst ``Int.Linear.mod_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue
|
||||
return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h
|
||||
| .pow ka ca? kb cb? =>
|
||||
let .add _ x _ := c'.p | c'.throwUnexpected
|
||||
let_expr HPow.hPow _ _ _ _ a b := (← getCurrVars)[x]! | c'.throwUnexpected
|
||||
let h₁ ← if let some ca := ca? then
|
||||
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl (← getVarOf a)) (toExpr ka) (← mkPolyDecl ca.p) eagerReflBoolTrue (← ca.toExprProof)
|
||||
else
|
||||
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit ka)
|
||||
let kbInt := Int.ofNat kb
|
||||
let h₂ ← if let some cb := cb? then
|
||||
let (b', _) ← mkNatVar b
|
||||
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl (← getVarOf b')) (toExpr kbInt) (← mkPolyDecl cb.p) eagerReflBoolTrue (← cb.toExprProof)
|
||||
else
|
||||
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit kb)
|
||||
let k := ka^kb
|
||||
let h := mkApp8 (mkConst ``Int.Linear.pow_eq) a b (toExpr ka) (toExpr kbInt) (toExpr k) h₁ h₂ eagerReflBoolTrue
|
||||
return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h
|
||||
|
||||
partial def mkMulEqProof (x : Var) (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) (c' : EqCnstr) : ProofM Expr := do
|
||||
let h ← go (← getCurrVars)[x]!
|
||||
|
|
@ -611,6 +627,7 @@ partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do u
|
|||
| .commRingNorm c .. | .reorder c | .norm c | .divCoeffs c | .div _ _ c | .mod _ _ c => c.collectDecVars
|
||||
| .subst _ c₁ c₂ | .ofLeGe c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
|
||||
| .mul _ cs => cs.forM fun (_, _, c) => c.collectDecVars
|
||||
| .pow _ ca? _ cb? => ca?.forM (·.collectDecVars); cb?.forM (·.collectDecVars)
|
||||
|
||||
partial def CooperSplit.collectDecVars (s : CooperSplit) : CollectDecVarsM Unit := do unless (← alreadyVisited s) do
|
||||
s.pred.c₁.collectDecVars
|
||||
|
|
|
|||
|
|
@ -115,6 +115,7 @@ inductive EqCnstrProof where
|
|||
- If `?y = none`, then it is a proof for `a % b = a%k` where `c` is a proof that `b = k`. `a` is a numeral in this case.
|
||||
-/
|
||||
mod (k : Int) (y? : Option Var) (c : EqCnstr)
|
||||
| pow (ka : Int) (ca? : Option EqCnstr) (kb : Nat) (cb? : Option EqCnstr)
|
||||
|
||||
/-- A divisibility constraint and its justification/proof. -/
|
||||
structure DvdCnstr where
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ private def isNonlinearTerm (e : Expr) : MetaM Bool := do
|
|||
| HMul.hMul _ _ _ i _ _ => isInstHMulInt i
|
||||
| HDiv.hDiv _ _ _ i _ b => pure (← getIntValue? b).isNone <&&> isInstHDivInt i
|
||||
| HMod.hMod _ _ _ i _ b => pure (← getIntValue? b).isNone <&&> isInstHModInt i
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
unless (← isInstHPowInt i) do return false
|
||||
return (← getIntValue? a).isNone || (← getIntValue? b).isNone
|
||||
| _ => return false
|
||||
|
||||
private def registerNonlinearOcc (arg : Expr) (x : Var) : GoalM Unit := do
|
||||
|
|
@ -31,7 +34,6 @@ private def registerNonlinearOcc (arg : Expr) (x : Var) : GoalM Unit := do
|
|||
if (← get').elimEqs[y]!.isSome then
|
||||
if (← propagateNonlinearTerm y x) then
|
||||
return ()
|
||||
let y ← mkVar arg
|
||||
let occs := (← get').nonlinearOccs.find? y |>.getD []
|
||||
unless x ∈ occs do
|
||||
modify' fun s => { s with nonlinearOccs := s.nonlinearOccs.insert y (x::occs) }
|
||||
|
|
@ -41,6 +43,14 @@ private partial def registerNonlinearOccsAt (e : Expr) (x : Var) : GoalM Unit :=
|
|||
| HMul.hMul _ _ _ _ a b => go a; go b
|
||||
| HDiv.hDiv _ _ _ _ _ b => registerNonlinearOcc b x
|
||||
| HMod.hMod _ _ _ _ _ b => registerNonlinearOcc b x
|
||||
| HPow.hPow _ _ _ _ a b =>
|
||||
if (← getIntValue? a).isNone then
|
||||
registerNonlinearOcc a x
|
||||
if (← getIntValue? b).isNone then
|
||||
-- Recall that `b : Nat`, we must create `NatCast.natCast b` and watch it.
|
||||
let (b', _) ← mkNatVar b
|
||||
internalize b' (← getGeneration b)
|
||||
registerNonlinearOcc b' x
|
||||
| _ => return ()
|
||||
where
|
||||
go (e : Expr) : GoalM Unit := do
|
||||
|
|
|
|||
|
|
@ -114,3 +114,6 @@ example {n : Nat} (j : Fin (n + 1)) : j ≤ j := by
|
|||
|
||||
example {n : Nat} (x y : Fin ((n + 1) + 1)) (h₂ : ¬x = y) (h : ¬x < y) : y < x := by
|
||||
grind
|
||||
|
||||
example {n m : Nat} (x : BitVec n) : 2 ≤ n → n ≤ m → m = 2 → x = 0 ∨ x = 1 ∨ x = 2 ∨ x = 3 := by
|
||||
grind
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ example [CommRing α] [IsCharP α 8] (x : α) : (x + 1)*(x - 1) = x^2 → False
|
|||
#guard_msgs (trace) in
|
||||
set_option trace.grind.ring.assert.queue true in
|
||||
example (x y : Int) : x + 16*y^2 - 7*x^2 = 0 → False := by
|
||||
fail_if_success grind
|
||||
fail_if_success grind -cutsat
|
||||
sorry
|
||||
|
||||
/--
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue