From de27872f3ffee8f2eb13f36fbb8fec9bcdfe3b8c Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 19 Apr 2025 22:12:09 -0700 Subject: [PATCH] feat: basic `CommRing` support in `grind` (#8029) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements basic support for `CommRing` in `grind`. Terms are already being reified and normalized. We still need to process the equations, but `grind` can already prove simple examples such as: ```lean open Lean.Grind in example [CommRing α] (x : α) : (x + 1)*(x - 1) = x^2 - 1 := by grind +ring open Lean.Grind in example [CommRing α] [IsCharP α 256] (x : α) : (x + 16)*(x - 16) = x^2 := by grind +ring example (x : Int) : (x + 1)*(x - 1) = x^2 - 1 := by grind +ring example (x : UInt8) : (x + 16)*(x - 16) = x^2 := by grind +ring example (x : Int) : (x + 1)^2 - 1 = x^2 + 2*x := by grind +ring example (x : BitVec 8) : (x + 16)*(x - 16) = x^2 := by grind +ring example (x : BitVec 8) : (x + 1)^2 - 1 = x^2 + 2*x := by grind +ring ``` --- src/Init/Grind/CommRing/Basic.lean | 10 +- src/Init/Grind/CommRing/Poly.lean | 28 ++++ src/Init/Grind/Tactics.lean | 4 + .../Meta/Tactic/Grind/Arith/CommRing.lean | 17 +++ .../Meta/Tactic/Grind/Arith/CommRing/Eq.lean | 42 ++++++ .../Grind/Arith/CommRing/Internalize.lean | 45 +++++++ .../Tactic/Grind/Arith/CommRing/Proof.lean | 37 ++++++ .../Tactic/Grind/Arith/CommRing/Reify.lean | 109 ++++++++++++++++ .../Tactic/Grind/Arith/CommRing/RingId.lean | 123 ++++++++++++++++++ .../Tactic/Grind/Arith/CommRing/ToExpr.lean | 57 ++++++++ .../Tactic/Grind/Arith/CommRing/Types.lean | 60 +++++++++ .../Tactic/Grind/Arith/CommRing/Util.lean | 63 +++++++++ .../Meta/Tactic/Grind/Arith/CommRing/Var.lean | 24 ++++ .../Meta/Tactic/Grind/Arith/Cutsat/Proof.lean | 5 - .../Meta/Tactic/Grind/Arith/Internalize.lean | 2 + src/Lean/Meta/Tactic/Grind/Arith/Types.lean | 2 + src/Lean/Meta/Tactic/Grind/Core.lean | 31 ++++- src/Lean/Meta/Tactic/Grind/Diseq.lean | 5 + src/Lean/Meta/Tactic/Grind/Propagate.lean | 1 + src/Lean/Meta/Tactic/Grind/Types.lean | 52 ++++++++ tests/lean/run/grind_ring_1.lean | 23 ++++ 21 files changed, 729 insertions(+), 11 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Eq.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Internalize.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Reify.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean create mode 100644 tests/lean/run/grind_ring_1.lean diff --git a/src/Init/Grind/CommRing/Basic.lean b/src/Init/Grind/CommRing/Basic.lean index 64adff4661..23632317f2 100644 --- a/src/Init/Grind/CommRing/Basic.lean +++ b/src/Init/Grind/CommRing/Basic.lean @@ -57,7 +57,7 @@ namespace CommRing variable {α : Type u} [CommRing α] -instance : NatCast α where +instance natCastInst : NatCast α where natCast n := OfNat.ofNat n theorem natCast_zero : ((0 : Nat) : α) = 0 := rfl @@ -125,7 +125,13 @@ theorem neg_sub (a b : α) : -(a - b) = b - a := by theorem sub_self (a : α) : a - a = 0 := by rw [sub_eq_add_neg, add_neg_cancel] -instance : IntCast α where +theorem eq_of_sub_eq_zero {a b : α} : a - b = 0 → a = b := by + intro h + replace h := congrArg (. + b) h; simp only at h + rw [sub_eq_add_neg, add_assoc, neg_add_cancel, add_zero, zero_add] at h + assumption + +instance intCastInst : IntCast α where intCast n := match n with | Int.ofNat n => OfNat.ofNat n | Int.negSucc n => -OfNat.ofNat (n + 1) diff --git a/src/Init/Grind/CommRing/Poly.lean b/src/Init/Grind/CommRing/Poly.lean index e136956b6f..82b1dc1697 100644 --- a/src/Init/Grind/CommRing/Poly.lean +++ b/src/Init/Grind/CommRing/Poly.lean @@ -640,6 +640,22 @@ theorem Expr.eq_of_toPoly_eq {α} [CommRing α] (ctx : Context α) (a b : Expr) simp [denote_toPoly] at h assumption +def ne_unsat_cert (a b : Expr) : Bool := + (a.sub b).toPoly == .num 0 + +theorem ne_unsat {α} [CommRing α] (ctx : Context α) (a b : Expr) + : ne_unsat_cert a b → a.denote ctx ≠ b.denote ctx → False := by + simp [ne_unsat_cert] + intro h + replace h := congrArg (Poly.denote ctx .) h + simp [Poly.denote, Expr.denote, Expr.denote_toPoly, intCast_zero] at h + replace h := eq_of_sub_eq_zero h + assumption + +/-! +Theorems for justifying the procedure for commutative rings with a characteristic in `grind`. +-/ + theorem Poly.denote_addConstC {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (p : Poly) (k : Int) : (addConstC p k c).denote ctx = p.denote ctx + k := by fun_induction addConstC <;> simp [addConstC, denote, *] next => rw [IsCharP.intCast_emod, intCast_add] @@ -758,5 +774,17 @@ theorem Expr.eq_of_toPolyC_eq {α c} [CommRing α] [IsCharP α c] (ctx : Context simp [denote_toPolyC] at h assumption +def ne_unsatC_cert (a b : Expr) (c : Nat) : Bool := + (a.sub b).toPolyC c == .num 0 + +theorem ne_unsatC {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (a b : Expr) + : ne_unsatC_cert a b c → a.denote ctx ≠ b.denote ctx → False := by + simp [ne_unsatC_cert] + intro h + replace h := congrArg (Poly.denote ctx .) h + simp [Poly.denote, Expr.denote, Expr.denote_toPolyC, intCast_zero] at h + replace h := eq_of_sub_eq_zero h + assumption + end CommRing end Lean.Grind diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index 63173f8f20..fc2ae8e2b7 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -112,6 +112,10 @@ structure Config where That is, `let x := v; e[x]` reduces to `e[v]`. See also `zetaDelta`. -/ zeta := true + /-- + When `true` (default: `false`), uses procedure for handling equalities over commutative rings. + -/ + ring := false deriving Inhabited, BEq end Lean.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean index 97ccff2b2a..5a446357d8 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean @@ -4,4 +4,21 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import Lean.Util.Trace import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly +import Lean.Meta.Tactic.Grind.Arith.CommRing.Types +import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId +import Lean.Meta.Tactic.Grind.Arith.CommRing.Internalize +import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr +import Lean.Meta.Tactic.Grind.Arith.CommRing.Var +import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify +import Lean.Meta.Tactic.Grind.Arith.CommRing.Eq +import Lean.Meta.Tactic.Grind.Arith.CommRing.Proof + +namespace Lean + +builtin_initialize registerTraceClass `grind.ring +builtin_initialize registerTraceClass `grind.ring.internalize +builtin_initialize registerTraceClass `grind.ring.assert + +end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Eq.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Eq.lean new file mode 100644 index 0000000000..808ca9684f --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Eq.lean @@ -0,0 +1,42 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId +import Lean.Meta.Tactic.Grind.Arith.CommRing.Proof + +namespace Lean.Meta.Grind.Arith.CommRing + +private def toRingExpr? (e : Expr) (ringId : Nat) : GoalM (Option RingExpr) := do + let ring ← getRing ringId + if let some re := ring.denote.find? { expr := e } then + return some re + else if let some x := ring.varMap.find? { expr := e } then + return some (.var x) + else + reportIssue! "failed to convert to ring expression{indentExpr e}" + return none + +@[export lean_process_ring_eq] +def processNewEqImpl (a b : Expr) : GoalM Unit := do + if isSameExpr a b then return () -- TODO: check why this is needed + trace[grind.ring] "{← mkEq a b}" + -- TODO + +@[export lean_process_ring_diseq] +def processNewDiseqImpl (a b : Expr) : GoalM Unit := do + let some ringId ← getTermRingId? a | return () + let some ringId' ← getTermRingId? b | return () + unless ringId == ringId' do return () -- This can happen when we have heterogeneous equalities + trace[grind.ring] "{mkNot (← mkEq a b)}" + let some e₁ ← toRingExpr? a ringId | return () + let some e₂ ← toRingExpr? b ringId | return () + let p ← toPoly (e₁.sub e₂) ringId + if p == .num 0 then + setNeUnsat ringId a b e₁ e₂ + return () + -- TODO: save disequalitys + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Internalize.lean new file mode 100644 index 0000000000..dbf586c859 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Internalize.lean @@ -0,0 +1,45 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Simp +import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId +import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify + +namespace Lean.Meta.Grind.Arith.CommRing + +/-- If `e` is a function application supported by the `CommRing` module, return its type. -/ +private def getType? (e : Expr) : Option Expr := + match_expr e with + | HAdd.hAdd α _ _ _ _ _ => some α + | HSub.hSub α _ _ _ _ _ => some α + | HMul.hMul α _ _ _ _ _ => some α + | HPow.hPow α β _ _ _ _ => + let_expr Nat := β | none + some α + | Neg.neg α _ _ => some α + | OfNat.ofNat α _ _ => some α + | NatCast.natCast α _ _ => some α + | IntCast.intCast α _ _ => some α + | _ => none + +private def isForbiddenParent (parent? : Option Expr) : Bool := + if let some parent := parent? then + getType? parent |>.isSome + else + false + +def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do + unless (← getConfig).ring do return () + let some type := getType? e | return () + if isForbiddenParent parent? then return () + let some ringId ← getRingId? type | return () + let some re ← reify? e ringId | return () + trace[grind.ring.internalize] "[{ringId}]: {e}" + setTermRingId e ringId + markAsCommRingTerm e + modifyRing ringId fun s => { s with denote := s.denote.insert { expr := e } re } + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean new file mode 100644 index 0000000000..1ad78ae08f --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean @@ -0,0 +1,37 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Diseq +import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId +import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr + +namespace Lean.Meta.Grind.Arith.CommRing + +/-- +Returns a context of type `RArray α` containing the variables of the given ring. +`α` is the type of the ring. +-/ +def toContextExpr (ringId : Nat) : GoalM Expr := do + let ring ← getRing ringId + if h : 0 < ring.vars.size then + RArray.toExpr ring.type id (RArray.ofFn (ring.vars[·]) h) + else + RArray.toExpr ring.type id (RArray.leaf (mkApp ring.natCastFn (toExpr 0))) + +private def mkLemmaPrefix (ringId : Nat) (declName declNameC : Name) : GoalM Expr := do + let ring ← getRing ringId + let ctx ← toContextExpr ringId + if let some (charInst, c) ← nonzeroCharInst? ringId then + return mkApp5 (mkConst declNameC [ring.u]) ring.type (toExpr c) ring.commRingInst charInst ctx + else + return mkApp3 (mkConst declName [ring.u]) ring.type ring.commRingInst ctx + +def setNeUnsat (ringId : Nat) (a b : Expr) (ra rb : RingExpr) : GoalM Unit := do + trace[grind.ring.assert] "unsat diseq {a}, {b}" + let h ← mkLemmaPrefix ringId ``Grind.CommRing.ne_unsat ``Grind.CommRing.ne_unsatC + closeGoal <| mkApp4 h (toExpr ra) (toExpr rb) reflBoolTrue (← mkDiseqProof a b) + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Reify.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Reify.lean new file mode 100644 index 0000000000..14718bff24 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Reify.lean @@ -0,0 +1,109 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Simp +import Lean.Meta.Tactic.Grind.Arith.CommRing.Util +import Lean.Meta.Tactic.Grind.Arith.CommRing.Var + +namespace Lean.Meta.Grind.Arith.CommRing + +def isAddInst (ring : Ring) (inst : Expr) : Bool := + isSameExpr ring.addFn.appArg! inst +def isMulInst (ring : Ring) (inst : Expr) : Bool := + isSameExpr ring.mulFn.appArg! inst +def isSubInst (ring : Ring) (inst : Expr) : Bool := + isSameExpr ring.subFn.appArg! inst +def isNegInst (ring : Ring) (inst : Expr) : Bool := + isSameExpr ring.negFn.appArg! inst +def isPowInst (ring : Ring) (inst : Expr) : Bool := + isSameExpr ring.powFn.appArg! inst +def isIntCastInst (ring : Ring) (inst : Expr) : Bool := + isSameExpr ring.intCastFn.appArg! inst +def isNatCastInst (ring : Ring) (inst : Expr) : Bool := + isSameExpr ring.natCastFn.appArg! inst + +private def reportAppIssue (e : Expr) : GoalM Unit := do + reportIssue! "comm ring term with unexpected instance{indentExpr e}" + +/-- +Converts a Lean expression `e` in the `CommRing` with id `ringId` into +a `CommRing.Expr` object. +-/ +partial def reify? (e : Expr) (ringId : Nat) : GoalM (Option RingExpr) := do + let ring ← getRing ringId + let toVar (e : Expr) : GoalM RingExpr := do + return .var (← mkVar e ringId) + let asVar (e : Expr) : GoalM RingExpr := do + reportAppIssue e + return .var (← mkVar e ringId) + let rec go (e : Expr) : GoalM RingExpr := do + match_expr e with + | HAdd.hAdd _ _ _ i a b => + if isAddInst ring i then return .add (← go a) (← go b) else asVar e + | HMul.hMul _ _ _ i a b => + if isMulInst ring i then return .mul (← go a) (← go b) else asVar e + | HSub.hSub _ _ _ i a b => + if isSubInst ring i then return .sub (← go a) (← go b) else asVar e + | HPow.hPow _ _ _ i a b => + let some k ← getNatValue? b | toVar e + if isPowInst ring i then return .pow (← go a) k else asVar e + | Neg.neg _ i a => + if isNegInst ring i then return .neg (← go a) else asVar e + | IntCast.intCast _ i e => + if isIntCastInst ring i then + let some k ← getIntValue? e | toVar e + return .num k + else + asVar e + | NatCast.natCast _ i e => + if isNatCastInst ring i then + let some k ← getNatValue? e | toVar e + return .num k + else + asVar e + | OfNat.ofNat _ n _ => + let some k ← getNatValue? n | toVar e + if (← withDefault <| isDefEq e (mkApp ring.natCastFn n)) then + return .num k + else + asVar e + | _ => toVar e + let asNone (e : Expr) : GoalM (Option RingExpr) := do + reportAppIssue e + return none + match_expr e with + | HAdd.hAdd _ _ _ i a b => + if isAddInst ring i then return some (.add (← go a) (← go b)) else asNone e + | HMul.hMul _ _ _ i a b => + if isMulInst ring i then return some (.mul (← go a) (← go b)) else asNone e + | HSub.hSub _ _ _ i a b => + if isSubInst ring i then return some (.sub (← go a) (← go b)) else asNone e + | HPow.hPow _ _ _ i a b => + let some k ← getNatValue? b | return none + if isPowInst ring i then return some (.pow (← go a) k) else asNone e + | Neg.neg _ i a => + if isNegInst ring i then return some (.neg (← go a)) else asNone e + | IntCast.intCast _ i e => + if isIntCastInst ring i then + let some k ← getIntValue? e | return none + return some (.num k) + else + asNone e + | NatCast.natCast _ i e => + if isNatCastInst ring i then + let some k ← getNatValue? e | return none + return some (.num k) + else + asNone e + | OfNat.ofNat _ n _ => + let some k ← getNatValue? n | return none + if (← withDefault <| isDefEq e (mkApp ring.natCastFn n)) then + return some (.num k) + else + asNone e + | _ => return none + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean new file mode 100644 index 0000000000..6e3c36f338 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean @@ -0,0 +1,123 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Simp +import Lean.Meta.Tactic.Grind.Arith.CommRing.Util + +namespace Lean.Meta.Grind.Arith.CommRing + +private def internalizeFn (fn : Expr) : GoalM Expr := do + shareCommon (← canon fn) + +private def getAddFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do + let instType := mkApp3 (mkConst ``HAdd [u, u, u]) type type type + let .some inst ← trySynthInstance instType | + throwError "failed to find instance for ring addition{indentExpr instType}" + let inst' := mkApp2 (mkConst ``instHAdd [u]) type <| mkApp2 (mkConst ``Grind.CommRing.toAdd [u]) type commRingInst + unless (← withDefault <| isDefEq inst inst') do + throwError "instance for addition{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" + internalizeFn <| mkApp4 (mkConst ``HAdd.hAdd [u, u, u]) type type type inst + +private def getMulFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do + let instType := mkApp3 (mkConst ``HMul [u, u, u]) type type type + let .some inst ← trySynthInstance instType | + throwError "failed to find instance for ring multiplication{indentExpr instType}" + let inst' := mkApp2 (mkConst ``instHMul [u]) type <| mkApp2 (mkConst ``Grind.CommRing.toMul [u]) type commRingInst + unless (← withDefault <| isDefEq inst inst') do + throwError "instance for multiplication{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" + internalizeFn <| mkApp4 (mkConst ``HMul.hMul [u, u, u]) type type type inst + +private def getSubFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do + let instType := mkApp3 (mkConst ``HSub [u, u, u]) type type type + let .some inst ← trySynthInstance instType | + throwError "failed to find instance for ring subtraction{indentExpr instType}" + let inst' := mkApp2 (mkConst ``instHSub [u]) type <| mkApp2 (mkConst ``Grind.CommRing.toSub [u]) type commRingInst + unless (← withDefault <| isDefEq inst inst') do + throwError "instance for subtraction{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" + internalizeFn <| mkApp4 (mkConst ``HSub.hSub [u, u, u]) type type type inst + +private def getNegFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do + let instType := mkApp (mkConst ``Neg [u]) type + let .some inst ← trySynthInstance instType | + throwError "failed to find instance for ring negation{indentExpr instType}" + let inst' := mkApp2 (mkConst ``Grind.CommRing.toNeg [u]) type commRingInst + unless (← withDefault <| isDefEq inst inst') do + throwError "instance for negation{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" + internalizeFn <| mkApp2 (mkConst ``Neg.neg [u]) type inst + +private def getPowFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do + let instType := mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type + let .some inst ← trySynthInstance instType | + throwError "failed to find instance for ring power operator{indentExpr instType}" + let inst' := mkApp2 (mkConst ``Grind.CommRing.toHPow [u]) type commRingInst + unless (← withDefault <| isDefEq inst inst') do + throwError "instance for power operator{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" + internalizeFn <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst + +private def getIntCastFn (type : Expr) (u : Level) (_commRingInst : Expr) : GoalM Expr := do + let instType := mkApp (mkConst ``IntCast [u]) type + let .some inst ← trySynthInstance instType | + throwError "failed to find instance for ring intCast{indentExpr instType}" + -- TODO uncomment after we fix `CommRing` definition + /- + let inst' := mkApp2 (mkConst ``Grind.CommRing.intCastInst [u]) type commRingInst + unless (← withDefault <| isDefEq inst inst') do + throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" + -/ + internalizeFn <| mkApp2 (mkConst ``IntCast.intCast [u]) type inst + +private def getNatCastFn (type : Expr) (u : Level) (commRingInst : Expr) : GoalM Expr := do + let instType := mkApp (mkConst ``NatCast [u]) type + let .some inst ← trySynthInstance instType | + throwError "failed to find instance for ring natCast{indentExpr instType}" + let inst' := mkApp2 (mkConst ``Grind.CommRing.natCastInst [u]) type commRingInst + unless (← withDefault <| isDefEq inst inst') do + throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.CommRing` one{indentExpr inst'}" + internalizeFn <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst + +/-- +Returns the ring id for the given type if there is a `CommRing` instance for it. + +This function will also perform sanity-checks +(e.g., the `Add` instance for `type` must be definitionally equal to the `CommRing.toAdd` one.) + +It also caches the functions representing `+`, `*`, `-`, `^`, and `intCast`. +-/ +def getRingId? (type : Expr) : GoalM (Option Nat) := do + if let some id? := (← get').typeIdOf.find? { expr := type } then + return id? + else + let id? ← go? + modify' fun s => { s with typeIdOf := s.typeIdOf.insert { expr := type } id? } + return id? +where + go? : GoalM (Option Nat) := do + let u ← getDecLevel type + let ring := mkApp (mkConst ``Grind.CommRing [u]) type + let .some commRingInst ← trySynthInstance ring | return none + trace[grind.ring] "new ring: {type}" + let charInst? ← withNewMCtxDepth do + let n ← mkFreshExprMVar (mkConst ``Nat) + let charType := mkApp3 (mkConst ``Grind.IsCharP [u]) type commRingInst n + let .some charInst ← trySynthInstance charType | pure none + let n ← instantiateMVars n + let some n ← evalNat n |>.run + | trace[grind.ring] "found instance for{indentExpr charType}\nbut characteristic is not a natural number"; pure none + trace[grind.ring] "characteristic: {n}" + pure <| some (charInst, n) + let addFn ← getAddFn type u commRingInst + let mulFn ← getMulFn type u commRingInst + let subFn ← getSubFn type u commRingInst + let negFn ← getNegFn type u commRingInst + let powFn ← getPowFn type u commRingInst + let intCastFn ← getIntCastFn type u commRingInst + let natCastFn ← getNatCastFn type u commRingInst + let id := (← get').rings.size + let ring : Ring := { type, u, commRingInst, charInst?, addFn, mulFn, subFn, negFn, powFn, intCastFn, natCastFn } + modify' fun s => { s with rings := s.rings.push ring } + return some id + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean new file mode 100644 index 0000000000..698affe3d5 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean @@ -0,0 +1,57 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Grind.CommRing.Poly +import Lean.ToExpr + +namespace Lean.Meta.Grind.Arith.CommRing +open Grind.CommRing +/-! +`ToExpr` instances for `CommRing.Poly` types. +-/ + +def ofPower (p : Power) : Expr := + mkApp2 (mkConst ``Power.mk) (toExpr p.x) (toExpr p.k) + +instance : ToExpr Power where + toExpr := ofPower + toTypeExpr := mkConst ``Power + +def ofMon (m : Mon) : Expr := + match m with + | .unit => mkConst ``Mon.unit + | .mult pw m => mkApp2 (mkConst ``Mon.mult) (toExpr pw) (ofMon m) + +instance : ToExpr Mon where + toExpr := ofMon + toTypeExpr := mkConst ``Mon + +def ofPoly (p : Poly) : Expr := + match p with + | .num k => mkApp (mkConst ``Poly.num) (toExpr k) + | .add k m p => mkApp3 (mkConst ``Poly.add) (toExpr k) (toExpr m) (ofPoly p) + +instance : ToExpr Poly where + toExpr := ofPoly + toTypeExpr := mkConst ``Poly + +open Lean.Grind + +def ofRingExpr (e : CommRing.Expr) : Expr := + match e with + | .num k => mkApp (mkConst ``CommRing.Expr.num) (toExpr k) + | .var x => mkApp (mkConst ``CommRing.Expr.var) (toExpr x) + | .add a b => mkApp2 (mkConst ``CommRing.Expr.add) (ofRingExpr a) (ofRingExpr b) + | .mul a b => mkApp2 (mkConst ``CommRing.Expr.mul) (ofRingExpr a) (ofRingExpr b) + | .sub a b => mkApp2 (mkConst ``CommRing.Expr.sub) (ofRingExpr a) (ofRingExpr b) + | .neg a => mkApp (mkConst ``CommRing.Expr.neg) (ofRingExpr a) + | .pow a k => mkApp2 (mkConst ``CommRing.Expr.pow) (ofRingExpr a) (toExpr k) + +instance : ToExpr CommRing.Expr where + toExpr := ofRingExpr + toTypeExpr := mkConst ``CommRing.Expr + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean new file mode 100644 index 0000000000..fa57a3b3d9 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean @@ -0,0 +1,60 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Data.PersistentArray +import Lean.Meta.Tactic.Grind.ENodeKey +import Lean.Meta.Tactic.Grind.Arith.Util +import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly + +namespace Lean.Meta.Grind.Arith.CommRing +export Lean.Grind.CommRing (Var Power Mon Poly) +abbrev RingExpr := Grind.CommRing.Expr + +deriving instance Repr for Power, Mon, Poly + +/-- State for each `CommRing` processed by this module. -/ +structure Ring where + type : Expr + /-- Cached `getDecLevel type` -/ + u : Level + /-- `CommRing` instance for `type` -/ + commRingInst : Expr + /-- `IsCharP` instance for `type` if available. -/ + charInst? : Option (Expr × Nat) := .none + addFn : Expr + mulFn : Expr + subFn : Expr + negFn : Expr + powFn : Expr + intCastFn : Expr + natCastFn : Expr + /-- + Mapping from variables to their denotations. + Remark each variable can be in only one ring. + -/ + vars : PArray Expr := {} + /-- Mapping from `Expr` to a variable representing it. -/ + varMap : PHashMap ENodeKey Var := {} + /-- Mapping from Lean expressions to their representations as `RingExpr` -/ + denote : PHashMap ENodeKey RingExpr := {} + deriving Inhabited + +/-- State for all `CommRing` types detected by `grind`. -/ +structure State where + /-- + Commutative rings. + We expect to find a small number of rings in a given goal. Thus, using `Array` is fine here. + -/ + rings : Array Ring := {} + /-- + Mapping from types to its "ring id". We cache failures using `none`. + `typeIdOf[type]` is `some id`, then `id < rings.size`. -/ + typeIdOf : PHashMap ENodeKey (Option Nat) := {} + /- Mapping from expressions/terms to their ring ids. -/ + exprToRingId : PHashMap ENodeKey Nat := {} + deriving Inhabited + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean new file mode 100644 index 0000000000..76fe005669 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Util.lean @@ -0,0 +1,63 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Types + +namespace Lean.Meta.Grind.Arith.CommRing + +def get' : GoalM State := do + return (← get).arith.ring + +@[inline] def modify' (f : State → State) : GoalM Unit := do + modify fun s => { s with arith.ring := f s.arith.ring } + +def getRing (ringId : Nat) : GoalM Ring := do + let s ← get' + if h : ringId < s.rings.size then + return s.rings[ringId] + else + throwError "`grind` internal error, invalid ringId" + +@[inline] def modifyRing (ringId : Nat) (f : Ring → Ring) : GoalM Unit := do + modify' fun s => { s with rings := s.rings.modify ringId f } + +def getTermRingId? (e : Expr) : GoalM (Option Nat) := do + return (← get').exprToRingId.find? { expr := e } + +def setTermRingId (e : Expr) (ringId : Nat) : GoalM Unit := do + if let some ringId' ← getTermRingId? e then + unless ringId' == ringId do + reportIssue! "expression in two different rings{indentExpr e}" + return () + modify' fun s => { s with exprToRingId := s.exprToRingId.insert { expr := e } ringId } + +/-- Returns `some c` if the given ring has a nonzero characteristic `c`. -/ +def nonzeroChar? (ringId : Nat) : GoalM (Option Nat) := do + let ring ← getRing ringId + if let some (_, c) := ring.charInst? then + if c != 0 then + return some c + return none + +/-- Returns `some (charInst, c)` if the given ring has a nonzero characteristic `c`. -/ +def nonzeroCharInst? (ringId : Nat) : GoalM (Option (Expr × Nat)) := do + let ring ← getRing ringId + if let some (inst, c) := ring.charInst? then + if c != 0 then + return some (inst, c) + return none + +/-- +Converts the given ring expression into a multivariate polynomial. +If the ring has a nonzero characteristic, it is used during normalization. +-/ +def toPoly (e : RingExpr) (ringId : Nat) : GoalM Poly := do + if let some c ← nonzeroChar? ringId then + return e.toPolyC c + else + return e.toPoly + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean new file mode 100644 index 0000000000..0a84ddc5af --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean @@ -0,0 +1,24 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Arith.CommRing.Util + +namespace Lean.Meta.Grind.Arith.CommRing + +def mkVar (e : Expr) (ringId : Nat) : GoalM Var := do + let s ← getRing ringId + if let some var := s.varMap.find? { expr := e } then + return var + let var : Var := s.vars.size + modifyRing ringId fun s => { s with + vars := s.vars.push e + varMap := s.varMap.insert { expr := e } var + } + setTermRingId e ringId + markAsCommRingTerm e + return var + +end Lean.Meta.Grind.Arith.CommRing diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index 018cc1a89f..1bbad42089 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -64,11 +64,6 @@ def mkNatExprDecl (e : Int.OfNat.Expr) : ProofM Expr := do modify fun s => { s with natExprMap := s.natExprMap.insert e x } return x -private def mkDiseqProof (a b : Expr) : GoalM Expr := do - let some h ← mkDiseqProof? a b - | throwError "internal `grind` error, failed to build disequality proof for{indentExpr a}\nand{indentExpr b}" - return h - private def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Expr) (varPrefix : Name) (varType : Expr) (toExpr : α → Expr) : GoalM Expr := do if m.isEmpty then diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean index 5426f79220..6bb5aef2c2 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean @@ -6,11 +6,13 @@ Authors: Leonardo de Moura prelude import Lean.Meta.Tactic.Grind.Arith.Offset import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr +import Lean.Meta.Tactic.Grind.Arith.CommRing.Internalize namespace Lean.Meta.Grind.Arith def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do Offset.internalize e parent? Cutsat.internalize e parent? + CommRing.internalize e parent? end Lean.Meta.Grind.Arith diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/Types.lean index bc37c82dc0..a9e786cff7 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Types.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Lean.Meta.Tactic.Grind.Arith.Offset.Types import Lean.Meta.Tactic.Grind.Arith.Cutsat.Types +import Lean.Meta.Tactic.Grind.Arith.CommRing.Types namespace Lean.Meta.Grind.Arith @@ -13,6 +14,7 @@ namespace Lean.Meta.Grind.Arith structure State where offset : Offset.State := {} cutsat : Cutsat.State := {} + ring : CommRing.State := {} deriving Inhabited end Lean.Meta.Grind.Arith diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index 861779a7a9..269e6c32c3 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -111,7 +111,7 @@ private def propagateOffsetEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do /-- Helper function for combining `ENode.cutsat?` fields and propagating equalities -to the offset constraint module. +to the cutsat module. It returns a set of parents that should be traversed for disequality propagation. -/ private def propagateCutsatEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do @@ -138,6 +138,28 @@ private def propagateCutsatEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do else return {} +/-- +Helper function for combining `ENode.ring?` fields and propagating equalities +to the commutative ring module. +It returns a set of parents that should be traversed for disequality propagation. +-/ +private def propagateCommRingEq (rhsRoot lhsRoot : ENode) : GoalM ParentSet := do + match lhsRoot.ring? with + | some lhsRing => + if let some rhsRing := rhsRoot.ring? then + Arith.CommRing.processNewEq lhsRing rhsRing + return {} + else + -- We have to retrieve the node because other fields have been updated + let rhsRoot ← getENode rhsRoot.self + setENode rhsRoot.self { rhsRoot with ring? := lhsRing } + getParents rhsRoot.self + | none => + if rhsRoot.ring?.isSome then + getParents lhsRoot.self + else + return {} + /-- Tries to apply beta-reductiong using the parent applications of the functions in `fns` with the lambda expressions in `lams`. @@ -241,7 +263,8 @@ where propagateBeta lams₁ fns₁ propagateBeta lams₂ fns₂ propagateOffsetEq rhsRoot lhsRoot - let parentsToPropagateDiseqs ← propagateCutsatEq rhsRoot lhsRoot + let parentsToPropagateCutsatDiseqs ← propagateCutsatEq rhsRoot lhsRoot + let parentsToPropagateRingDiseqs ← propagateCommRingEq rhsRoot lhsRoot resetParentsOf lhsRoot.self copyParentsTo parents rhsNode.root unless (← isInconsistent) do @@ -251,8 +274,8 @@ where propagateUp parent for e in toPropagateDown do propagateDown e - propagateCutsatDiseqs parentsToPropagateDiseqs - + propagateCutsatDiseqs parentsToPropagateCutsatDiseqs + propagateCommRingDiseqs parentsToPropagateRingDiseqs updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do traverseEqc lhs fun n => setENode n.self { n with root := rootNew } diff --git a/src/Lean/Meta/Tactic/Grind/Diseq.lean b/src/Lean/Meta/Tactic/Grind/Diseq.lean index 1f5d223ae0..89275f0969 100644 --- a/src/Lean/Meta/Tactic/Grind/Diseq.lean +++ b/src/Lean/Meta/Tactic/Grind/Diseq.lean @@ -78,4 +78,9 @@ def mkDiseqProof? (a b : Expr) : GoalM (Option Expr) := do else return mkApp6 (mkConst ``Grind.ne_of_ne_of_eq_right u) α b a d (← mkEqProof b d) h +def mkDiseqProof (a b : Expr) : GoalM Expr := do + let some h ← mkDiseqProof? a b + | throwError "internal `grind` error, failed to build disequality proof for{indentExpr a}\nand{indentExpr b}" + return h + end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index 16f06e0903..636b2242e2 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -160,6 +160,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do propagateBoolDiseq lhs propagateBoolDiseq rhs propagateCutsatDiseq lhs rhs + propagateCommRingDiseq lhs rhs let thms ← getExtTheorems α if !thms.isEmpty then /- diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index b59a281746..983406f656 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -323,6 +323,11 @@ structure ENode where to the cutsat module. Its implementation is similar to the `offset?` field. -/ cutsat? : Option Expr := none + /-- + The `ring?` field is used to propagate equalities from the `grind` congruence closure module + to the comm ring module. Its implementation is similar to the `offset?` field. + -/ + ring? : Option Expr := none -- Remark: we expect to have builtin support for offset constraints, cutsat, and comm ring. -- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3. deriving Inhabited, Repr @@ -1015,6 +1020,53 @@ def markAsCutsatTerm (e : Expr) : GoalM Unit := do setENode root.self { root with cutsat? := some e } propagateCutsatDiseqs (← getParents root.self) +/-- +Notifies the comm ring module that `a = b` where +`a` and `b` are terms that have been internalized by this module. +-/ +@[extern "lean_process_ring_eq"] -- forward definition +opaque Arith.CommRing.processNewEq (a b : Expr) : GoalM Unit + +/-- +Notifies the comm ring module that `a ≠ b` where +`a` and `b` are terms that have been internalized by this module. +-/ +@[extern "lean_process_ring_diseq"] -- forward definition +opaque Arith.CommRing.processNewDiseq (a b : Expr) : GoalM Unit + +/-- +Given `lhs` and `rhs` that are known to be disequal, checks whether +`lhs` and `rhs` have ring terms `e₁` and `e₂` attached to them, +and invokes process `Arith.CommRing.processNewDiseq e₁ e₂` +-/ +def propagateCommRingDiseq (lhs rhs : Expr) : GoalM Unit := do + let some lhs ← get? lhs | return () + let some rhs ← get? rhs | return () + Arith.CommRing.processNewDiseq lhs rhs +where + get? (a : Expr) : GoalM (Option Expr) := do + return (← getRootENode a).ring? + +/-- +Traverses disequalities in `parents`, and propagate the ones relevant to the +comm ring module. +-/ +def propagateCommRingDiseqs (parents : ParentSet) : GoalM Unit := do + forEachDiseq parents propagateCommRingDiseq + +/-- +Marks `e` as a term of interest to the ring module. +If the root of `e`s equivalence class has already a term of interest, +a new equality is propagated to the ring module. +-/ +def markAsCommRingTerm (e : Expr) : GoalM Unit := do + let root ← getRootENode e + if let some e' := root.ring? then + Arith.CommRing.processNewEq e e' + else + setENode root.self { root with ring? := some e } + propagateCommRingDiseqs (← getParents root.self) + /-- Returns `true` is `e` is the root of its congruence class. -/ def isCongrRoot (e : Expr) : GoalM Bool := do return (← getENode e).isCongrRoot diff --git a/tests/lean/run/grind_ring_1.lean b/tests/lean/run/grind_ring_1.lean new file mode 100644 index 0000000000..70e62d5982 --- /dev/null +++ b/tests/lean/run/grind_ring_1.lean @@ -0,0 +1,23 @@ +set_option grind.warning false +open Lean.Grind + +example [CommRing α] (x : α) : (x + 1)*(x - 1) = x^2 - 1 := by + grind +ring + +example [CommRing α] [IsCharP α 256] (x : α) : (x + 16)*(x - 16) = x^2 := by + grind +ring + +example (x : Int) : (x + 1)*(x - 1) = x^2 - 1 := by + grind +ring + +example (x : UInt8) : (x + 16)*(x - 16) = x^2 := by + grind +ring + +example (x : Int) : (x + 1)^2 - 1 = x^2 + 2*x := by + grind +ring + +example (x : BitVec 8) : (x + 16)*(x - 16) = x^2 := by + grind +ring + +example (x : BitVec 8) : (x + 1)^2 - 1 = x^2 + 2*x := by + grind +ring