diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index 8ccb80d868..9fa4005031 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -2198,6 +2198,9 @@ theorem mod_eq (a b k : Int) (h : b = k) : a % b = a % k := by simp [*] theorem div_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a/b') : a / b = k := by simp_all theorem mod_eq' (a b b' k : Int) (h₁ : b = b') (h₂ : k == a%b') : a % b = k := by simp_all +theorem pow_eq (a : Int) (b : Nat) (a' b' k : Int) (h₁ : a = a') (h₂ : ↑b = b') (h₃ : k == a'^b'.toNat) : a^b = k := by + simp [← h₁, ← h₂] at h₃; simp [h₃] + end Int.Linear theorem Int.not_le_eq (a b : Int) : (¬a ≤ b) = (b + 1 ≤ a) := by diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean index aca0a45ea5..ca8ce5aa39 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean @@ -249,6 +249,29 @@ private def propagateNonlinearMod (x : Var) : GoalM Bool := do c'.assert return true +private def propagateNonlinearPow (x : Var) : GoalM Bool := do + let e ← getVar x + let_expr HPow.hPow _ _ _ i a b := e | return false + unless (← isInstHPowInt i) do return false + let (ka, ca?) ← if let some ka ← getIntValue? a then + pure (ka, none) + else if let some (ka, ca) ← isExprEqConst? a then + pure (ka, some ca) + else + return false + let (kb, cb?) ← if let some kb ← getNatValue? b then + pure (kb, none) + else + let (b', _) ← mkNatVar b + if let some (kb, cb) ← isExprEqConst? b' then + pure (kb.toNat, some cb) + else + return false + trace[Meta.debug] ">> e: {e}, k: {ka^kb}" + let c' ← pure { p := .add 1 x (.num (-(ka^kb))), h := .pow ka ca? kb cb? : EqCnstr } + c'.assert + return true + @[export lean_cutsat_propagate_nonlinear] def propagateNonlinearTermImpl (y : Var) (x : Var) : GoalM Bool := do unless (← isVarEqConst? y).isSome do return false @@ -256,6 +279,7 @@ def propagateNonlinearTermImpl (y : Var) (x : Var) : GoalM Bool := do | HMul.hMul _ _ _ _ _ _ => propagateNonlinearMul x | HDiv.hDiv _ _ _ _ _ _ => propagateNonlinearDiv x | HMod.hMod _ _ _ _ _ _ => propagateNonlinearMod x + | HPow.hPow _ _ _ _ _ _ => propagateNonlinearPow x | _ => return false def propagateNonlinearTerms (y : Var) : GoalM Unit := do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Nat.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Nat.lean index d0ef0ff483..4ae2363f61 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Nat.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Nat.lean @@ -16,10 +16,11 @@ public section namespace Lean.Meta.Grind.Arith.Cutsat +/-- Given `e`, returns `(NatCast.natCast e, rfl)` -/ def mkNatVar (e : Expr) : GoalM (Expr × Expr) := do if let some p := (← get').natToIntMap.find? { expr := e } then return p - let e' := mkIntNatCast e + let e' ← shareCommon (mkIntNatCast e) let he := mkApp (mkApp (mkConst ``Eq.refl [1]) Int.mkType) e' let r := (e', he) modify' fun s => { s with diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index 91de39c1dd..63d402b44e 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -309,6 +309,22 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do let k := aVal % b' let h := mkApp6 (mkConst ``Int.Linear.mod_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h + | .pow ka ca? kb cb? => + let .add _ x _ := c'.p | c'.throwUnexpected + let_expr HPow.hPow _ _ _ _ a b := (← getCurrVars)[x]! | c'.throwUnexpected + let h₁ ← if let some ca := ca? then + pure <| mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl (← getVarOf a)) (toExpr ka) (← mkPolyDecl ca.p) eagerReflBoolTrue (← ca.toExprProof) + else + pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit ka) + let kbInt := Int.ofNat kb + let h₂ ← if let some cb := cb? then + let (b', _) ← mkNatVar b + pure <| mkApp6 (mkConst ``Int.Linear.var_eq) (← getContext) (← mkVarDecl (← getVarOf b')) (toExpr kbInt) (← mkPolyDecl cb.p) eagerReflBoolTrue (← cb.toExprProof) + else + pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit kb) + let k := ka^kb + let h := mkApp8 (mkConst ``Int.Linear.pow_eq) a b (toExpr ka) (toExpr kbInt) (toExpr k) h₁ h₂ eagerReflBoolTrue + return mkApp6 (mkConst ``Int.Linear.of_var_eq) (← getContext) (← mkVarDecl x) (toExpr k) (← mkPolyDecl c'.p) eagerReflBoolTrue h partial def mkMulEqProof (x : Var) (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) (c' : EqCnstr) : ProofM Expr := do let h ← go (← getCurrVars)[x]! @@ -611,6 +627,7 @@ partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do u | .commRingNorm c .. | .reorder c | .norm c | .divCoeffs c | .div _ _ c | .mod _ _ c => c.collectDecVars | .subst _ c₁ c₂ | .ofLeGe c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars | .mul _ cs => cs.forM fun (_, _, c) => c.collectDecVars + | .pow _ ca? _ cb? => ca?.forM (·.collectDecVars); cb?.forM (·.collectDecVars) partial def CooperSplit.collectDecVars (s : CooperSplit) : CollectDecVarsM Unit := do unless (← alreadyVisited s) do s.pred.c₁.collectDecVars diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean index 410e0bca65..d51a2a46f4 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean @@ -115,6 +115,7 @@ inductive EqCnstrProof where - If `?y = none`, then it is a proof for `a % b = a%k` where `c` is a proof that `b = k`. `a` is a numeral in this case. -/ mod (k : Int) (y? : Option Var) (c : EqCnstr) + | pow (ka : Int) (ca? : Option EqCnstr) (kb : Nat) (cb? : Option EqCnstr) /-- A divisibility constraint and its justification/proof. -/ structure DvdCnstr where diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean index 22dc6ebfaf..7d7b7f388f 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Var.lean @@ -24,6 +24,9 @@ private def isNonlinearTerm (e : Expr) : MetaM Bool := do | HMul.hMul _ _ _ i _ _ => isInstHMulInt i | HDiv.hDiv _ _ _ i _ b => pure (← getIntValue? b).isNone <&&> isInstHDivInt i | HMod.hMod _ _ _ i _ b => pure (← getIntValue? b).isNone <&&> isInstHModInt i + | HPow.hPow _ _ _ i a b => + unless (← isInstHPowInt i) do return false + return (← getIntValue? a).isNone || (← getIntValue? b).isNone | _ => return false private def registerNonlinearOcc (arg : Expr) (x : Var) : GoalM Unit := do @@ -31,7 +34,6 @@ private def registerNonlinearOcc (arg : Expr) (x : Var) : GoalM Unit := do if (← get').elimEqs[y]!.isSome then if (← propagateNonlinearTerm y x) then return () - let y ← mkVar arg let occs := (← get').nonlinearOccs.find? y |>.getD [] unless x ∈ occs do modify' fun s => { s with nonlinearOccs := s.nonlinearOccs.insert y (x::occs) } @@ -41,6 +43,14 @@ private partial def registerNonlinearOccsAt (e : Expr) (x : Var) : GoalM Unit := | HMul.hMul _ _ _ _ a b => go a; go b | HDiv.hDiv _ _ _ _ _ b => registerNonlinearOcc b x | HMod.hMod _ _ _ _ _ b => registerNonlinearOcc b x + | HPow.hPow _ _ _ _ a b => + if (← getIntValue? a).isNone then + registerNonlinearOcc a x + if (← getIntValue? b).isNone then + -- Recall that `b : Nat`, we must create `NatCast.natCast b` and watch it. + let (b', _) ← mkNatVar b + internalize b' (← getGeneration b) + registerNonlinearOcc b' x | _ => return () where go (e : Expr) : GoalM Unit := do diff --git a/tests/lean/run/grind_cutsat_toint_1.lean b/tests/lean/run/grind_cutsat_toint_1.lean index 16390e356b..8550769047 100644 --- a/tests/lean/run/grind_cutsat_toint_1.lean +++ b/tests/lean/run/grind_cutsat_toint_1.lean @@ -114,3 +114,6 @@ example {n : Nat} (j : Fin (n + 1)) : j ≤ j := by example {n : Nat} (x y : Fin ((n + 1) + 1)) (h₂ : ¬x = y) (h : ¬x < y) : y < x := by grind + +example {n m : Nat} (x : BitVec n) : 2 ≤ n → n ≤ m → m = 2 → x = 0 ∨ x = 1 ∨ x = 2 ∨ x = 3 := by + grind diff --git a/tests/lean/run/grind_ring_1.lean b/tests/lean/run/grind_ring_1.lean index 1a8d85f307..4ead128179 100644 --- a/tests/lean/run/grind_ring_1.lean +++ b/tests/lean/run/grind_ring_1.lean @@ -66,7 +66,7 @@ example [CommRing α] [IsCharP α 8] (x : α) : (x + 1)*(x - 1) = x^2 → False #guard_msgs (trace) in set_option trace.grind.ring.assert.queue true in example (x y : Int) : x + 16*y^2 - 7*x^2 = 0 → False := by - fail_if_success grind + fail_if_success grind -cutsat sorry /--