From 6b520ede0863ebb540dabc67dacb41854c5a34d1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 26 Jun 2025 17:28:51 -0700 Subject: [PATCH] feat: generic `toInt` for cutsat (#9022) This PR completes the generic `toInt` infrastructure for embedding terms implementing the `ToInt` type classes into `Int`. --- src/Init/Grind/ToIntLemmas.lean | 39 ++++- .../Tactic/Grind/Arith/Cutsat/EqCnstr.lean | 15 +- .../Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean | 152 +++++++++++++++--- .../Tactic/Grind/Arith/Cutsat/ToIntInfo.lean | 42 ++++- 4 files changed, 218 insertions(+), 30 deletions(-) diff --git a/src/Init/Grind/ToIntLemmas.lean b/src/Init/Grind/ToIntLemmas.lean index fea63c6031..82fe35451f 100644 --- a/src/Init/Grind/ToIntLemmas.lean +++ b/src/Init/Grind/ToIntLemmas.lean @@ -89,8 +89,45 @@ theorem mul_congr.wl {α i} [ToInt α i] [_root_.Mul α] [ToInt.Mul α i] (h : i have := i.wrap_eq_self_iff h' _ |>.mpr (ToInt.toInt_mem b) rw [h₂] at this; rw [← this] at h₂; apply mul_congr.ww h h₁ h₂ --- TODO: add theorems for other operations +/-! Subtraction -/ +theorem sub_congr {α i} [ToInt α i] [_root_.Sub α] [ToInt.Sub α i] {a b : α} {a' b' : Int} + (h₁ : toInt a = a') (h₂ : toInt b = b') : toInt (a - b) = i.wrap (a' - b') := by + rw [ToInt.Sub.toInt_sub, h₁, h₂] +/-! Negation -/ + +theorem neg_congr {α i} [ToInt α i] [_root_.Neg α] [ToInt.Neg α i] {a : α} {a' : Int} + (h₁ : toInt a = a') : toInt (- a) = i.wrap (- a') := by + rw [ToInt.Neg.toInt_neg, h₁] + +/-! Power -/ + +theorem pow_congr {α i} [ToInt α i] [HPow α Nat α] [ToInt.Pow α i] {a : α} (k : Nat) (a' : Int) + (h₁ : toInt a = a') : toInt (a ^ k) = i.wrap (a' ^ k) := by + rw [ToInt.Pow.toInt_pow, h₁] + +/-! Division -/ + +theorem div_congr {α i} [ToInt α i] [_root_.Div α] [ToInt.Div α i] {a b : α} {a' b' : Int} + (h₁ : toInt a = a') (h₂ : toInt b = b') : toInt (a / b) = a' / b' := by + rw [ToInt.Div.toInt_div, h₁, h₂] + +/-! Modulo -/ + +theorem mod_congr {α i} [ToInt α i] [_root_.Mod α] [ToInt.Mod α i] {a b : α} {a' b' : Int} + (h₁ : toInt a = a') (h₂ : toInt b = b') : toInt (a % b) = a' % b' := by + rw [ToInt.Mod.toInt_mod, h₁, h₂] + +/-! OfNat -/ + +theorem ofNat_eq {α i} [ToInt α i] [∀ n, _root_.OfNat α n] [ToInt.OfNat α i] (n : Nat) + : toInt (OfNat.ofNat (α := α) n) = i.wrap n := by + apply ToInt.OfNat.toInt_ofNat + +/-! Zero -/ + +theorem zero_eq {α i} [ToInt α i] [_root_.Zero α] [ToInt.Zero α i] : toInt (0 : α) = 0 := by + apply ToInt.Zero.toInt_zero end Lean.Grind.ToInt diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean index 0af8329eec..9a05c52e91 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean @@ -382,6 +382,13 @@ private def internalizeNat (e : Expr) : GoalM Unit := do let c := { p := .add (-1) x p, h := .defnNat e' x e'' : EqCnstr } c.assert + +private def isToIntForbiddenParent (parent? : Option Expr) : Bool := + if let some parent := parent? then + getKindAndType? parent |>.isSome + else + false + /-- Internalizes an integer (and `Nat`) expression. Here are the different cases that are handled. @@ -394,14 +401,16 @@ Internalizes an integer (and `Nat`) expression. Here are the different cases tha def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do unless (← getConfig).cutsat do return () let some (k, type) := getKindAndType? e | return () - if isForbiddenParent parent? k then return () - trace[grind.debug.cutsat.internalize] "{e} : {type}" if type.isConstOf ``Int then + if isForbiddenParent parent? k then return () + trace[grind.debug.cutsat.internalize] "{e} : {type}" match k with | .div => propagateDiv e | .mod => propagateMod e | _ => internalizeInt e else if type.isConstOf ``Nat then + if isForbiddenParent parent? k then return () + trace[grind.debug.cutsat.internalize] "{e} : {type}" if (← hasForeignVar e) then return () discard <| mkForeignVar e .nat match k with @@ -410,6 +419,8 @@ def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do | .toNat => propagateToNat e | _ => internalizeNat e else if let some (e', h) ← toInt? e type then + if isToIntForbiddenParent parent? then return () + trace[grind.debug.cutsat.internalize] "{e} : {type}" -- TODO: save `(e', h)` trace[grind.debug.cutsat.toInt] "{e} ==> {e'}" trace[grind.debug.cutsat.toInt] "{h} : {← inferType h}" diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean index 10b57a68b1..2095736cf1 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean @@ -19,16 +19,41 @@ private def checkDecl (declName : Name) : MetaM Unit := do unless (← getEnv).contains declName do throwMissingDecl declName -private def mkOpCongr (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (range : Grind.IntInterval) (opBaseName : Name) (thmName : Name) : MetaM ToIntCongr := do - let op := mkApp (mkConst opBaseName [u]) type - let .some opInst ← trySynthInstance op | return {} - let toIntOpName := ``Grind.ToInt ++ opBaseName +private def mkOfNatThm? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) : MetaM (Option Expr) := do + -- ∀ n, OfNat α n + let ofNat := mkForall `n .default (mkConst ``Nat) (mkApp2 (mkConst ``OfNat [u]) type (mkBVar 0)) + let .some ofNatInst ← trySynthInstance ofNat + | reportMissingToIntAdapter type ofNat; return none + let toIntOfNat := mkApp4 (mkConst ``Grind.ToInt.OfNat [u]) type ofNatInst rangeExpr toIntInst + let .some toIntOfNatInst ← trySynthInstance toIntOfNat + | reportMissingToIntAdapter type toIntOfNat; return none + return mkApp5 (mkConst ``Grind.ToInt.ofNat_eq [u]) type rangeExpr toIntInst ofNatInst toIntOfNatInst + +/-- Helper function for `mkSimpleOpThm?` and `mkPowThm?` -/ +private def mkSimpleOpThmCore? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (op : Expr) (opSuffix : Name) (thmName : Name) : MetaM (Option Expr) := do + let .some opInst ← trySynthInstance op | return none + let toIntOpName := ``Grind.ToInt ++ opSuffix checkDecl toIntOpName let toIntOp := mkApp4 (mkConst toIntOpName [u]) type opInst rangeExpr toIntInst let .some toIntOpInst ← trySynthInstance toIntOp - | reportMissingToIntAdapter type toIntOp; return {} + | reportMissingToIntAdapter type toIntOp; return none checkDecl thmName - let c := mkApp5 (mkConst thmName [u]) type rangeExpr toIntInst opInst toIntOpInst + return mkApp5 (mkConst thmName [u]) type rangeExpr toIntInst opInst toIntOpInst + +/-- Simpler version of `mkBinOpThms` for operators that have only one congruence theorem. -/ +private def mkSimpleOpThm? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (opBaseName : Name) (thmName : Name) : MetaM (Option Expr) := do + let op := mkApp (mkConst opBaseName [u]) type + mkSimpleOpThmCore? type u toIntInst rangeExpr op opBaseName thmName + +/-- Simpler version of `mkBinOpThms` for operators that have only one congruence theorem. -/ +private def mkPowThm? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) : MetaM (Option Expr) := do + let op := mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type + mkSimpleOpThmCore? type u toIntInst rangeExpr op `Pow ``Grind.ToInt.pow_congr + +private def mkBinOpThms (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (range : Grind.IntInterval) (opBaseName : Name) (thmName : Name) : MetaM ToIntThms := do + let some c ← mkSimpleOpThm? type u toIntInst rangeExpr opBaseName thmName | return {} + let opInst := c.appFn!.appArg! + let toIntOpInst := c.appArg! let env ← getEnv let cwwName := thmName ++ `ww let cwlName := thmName ++ `wl @@ -118,13 +143,22 @@ where let ofLT := mkApp5 (mkConst ``Grind.ToInt.of_lt [u]) type rangeExpr toIntInst ltInst toIntLTInst let ofNotLT := mkApp5 (mkConst ``Grind.ToInt.of_not_lt [u]) type rangeExpr toIntInst ltInst toIntLTInst pure (some ofLT, some ofNotLT) - let mkOp (opBaseName : Name) (thmName : Name) := - mkOpCongr type u toIntInst rangeExpr range opBaseName thmName - let add ← mkOp ``Add ``Grind.ToInt.add_congr - let mul ← mkOp ``Mul ``Grind.ToInt.mul_congr - -- TODO: other operators + let mkBinOpThms (opBaseName : Name) (thmName : Name) := + mkBinOpThms type u toIntInst rangeExpr range opBaseName thmName + let mkSimpleOpThm? (opBaseName : Name) (thmName : Name) := + mkSimpleOpThm? type u toIntInst rangeExpr opBaseName thmName + let addThms ← mkBinOpThms ``Add ``Grind.ToInt.add_congr + let mulThms ← mkBinOpThms ``Mul ``Grind.ToInt.mul_congr + let subThm? ← mkSimpleOpThm? ``Sub ``Grind.ToInt.sub_congr + let negThm? ← mkSimpleOpThm? ``Neg ``Grind.ToInt.neg_congr + let divThm? ← mkSimpleOpThm? ``Div ``Grind.ToInt.div_congr + let modThm? ← mkSimpleOpThm? ``Mod ``Grind.ToInt.mod_congr + let powThm? ← mkPowThm? type u toIntInst rangeExpr + let zeroThm? ← mkSimpleOpThm? ``Zero ``Grind.ToInt.zero_eq + let ofNatThm? ← mkOfNatThm? type u toIntInst rangeExpr return some { - type, u, toIntInst, rangeExpr, range, toInt, wrap, ofWrap0?, ofEq, ofDiseq, ofLE?, ofNotLE?, ofLT?, ofNotLT?, add, mul + type, u, toIntInst, rangeExpr, range, toInt, wrap, ofWrap0?, ofEq, ofDiseq, ofLE?, ofNotLE?, ofLT?, ofNotLT?, addThms, mulThms, + subThm?, negThm?, divThm?, modThm?, powThm?, zeroThm?, ofNatThm? } structure ToIntM.Context where @@ -171,34 +205,106 @@ private def expandWrap (a b : Expr) (h : Expr) : ToIntM (Expr × Expr) := do return (b', h) | _ => throwError "`grind cutsat`, `ToInt` interval not supported yet" -private def mkToIntResult (toIntOp : ToIntCongr) (mkBinOp : Expr → Expr → Expr) (a b : Expr) (a' b' : Expr) (h₁ h₂ : Expr) : ToIntM (Expr × Expr) := do - let f := toIntOp.c?.get! +/-- +Given `h : toInt a = b`, if `b` is of the form `i.wrap b'`, +invokes `expandWrap a b' h` +-/ +private def expandIfWrap (a b : Expr) (h : Expr) : ToIntM (Expr × Expr) := do + match isWrap b with + | none => return (b, h) + | some b => expandWrap a b h + +private def mkWrap (a : Expr) : ToIntM Expr := do + return mkApp (← getInfo).wrap a + +private def ToIntThms.mkResult (toIntThms : ToIntThms) (mkBinOp : Expr → Expr → Expr) (a b : Expr) (a' b' : Expr) (h₁ h₂ : Expr) : ToIntM (Expr × Expr) := do + let f := toIntThms.c?.get! let mk (f : Expr) (a' b' : Expr) : ToIntM (Expr × Expr) := do + -- If the appropriate `wrap` cancellation theorem is missing, we have to expand the nested wrap. + let (a', h₁) ← expandIfWrap a a' h₁ + let (b', h₂) ← expandIfWrap b b' h₂ let h := mkApp6 f a b a' b' h₁ h₂ - let r := mkApp (← getInfo).wrap (mkBinOp a' b') + let r ← mkWrap (mkBinOp a' b') return (r, h) match isWrap a', isWrap b' with | none, none => mk f a' b' - | some a'', none => if let some f := toIntOp.c_wl? then mk f a'' b' else mk f a' b' - | none, some b'' => if let some f := toIntOp.c_wr? then mk f a' b'' else mk f a' b' - | some a'', some b'' => if let some f := toIntOp.c_ww? then mk f a'' b'' else mk f a' b' + | some a'', none => if let some f := toIntThms.c_wl? then mk f a'' b' else mk f a' b' + | none, some b'' => if let some f := toIntThms.c_wr? then mk f a' b'' else mk f a' b' + | some a'', some b'' => if let some f := toIntThms.c_ww? then mk f a'' b'' else mk f a' b' partial def toInt (e : Expr) : ToIntM (Expr × Expr) := do match_expr e with | HAdd.hAdd α β γ _ a b => unless isHomo α β γ do return (← toIntDef e) - toIntBin (← getInfo).add mkIntAdd a b + toIntBin (← getInfo).addThms mkIntAdd a b | HMul.hMul α β γ _ a b => unless isHomo α β γ do return (← toIntDef e) - toIntBin (← getInfo).mul mkIntMul a b - -- TODO: other operators + toIntBin (← getInfo).mulThms mkIntMul a b + | HDiv.hDiv α β γ _ a b => + unless isHomo α β γ do return (← toIntDef e) + processDivMod (isDiv := true) a b + | HMod.hMod α β γ _ a b => + unless isHomo α β γ do return (← toIntDef e) + processDivMod (isDiv := false) a b + | HSub.hSub α β γ _ a b => + unless isHomo α β γ do return (← toIntDef e) + processSub a b + | Neg.neg _ _ a => + processNeg a + | HPow.hPow α β γ _ a b => + unless isSameExpr α γ && β.isConstOf ``Nat do return (← toIntDef e) + processPow a b + | Zero.zero _ _ => + let some thm := (← getInfo).zeroThm? | toIntDef e + return (mkIntLit 0, thm) + | OfNat.ofNat _ n _ => + let some thm := (← getInfo).ofNatThm? | toIntDef e + let some n ← getNatValue? n | toIntDef e + let r := mkIntLit ((← getInfo).range.wrap n) + let h := mkApp thm (toExpr n) + return (r, h) | _ => toIntDef e where - toIntBin (toIntOp : ToIntCongr) (mkBinOp : Expr → Expr → Expr) (a b : Expr) : ToIntM (Expr × Expr) := do + toIntBin (toIntOp : ToIntThms) (mkBinOp : Expr → Expr → Expr) (a b : Expr) : ToIntM (Expr × Expr) := do unless toIntOp.c?.isSome do return (← toIntDef e) let (a', h₁) ← toInt a let (b', h₂) ← toInt b - mkToIntResult toIntOp mkBinOp a b a' b' h₁ h₂ + toIntOp.mkResult mkBinOp a b a' b' h₁ h₂ + + toIntAndExpandWrap (a : Expr) : ToIntM (Expr × Expr) := do + let (a', h₁) ← toInt a + expandIfWrap a a' h₁ + + processDivMod (isDiv : Bool) (a b : Expr) : ToIntM (Expr × Expr) := do + let some thm ← if isDiv then pure (← getInfo).divThm? else pure (← getInfo).modThm? + | return (← toIntDef e) + let (a', h₁) ← toIntAndExpandWrap a + let (b', h₂) ← toIntAndExpandWrap b + let r := if isDiv then mkIntDiv a' b' else mkIntMod a' b' + let h := mkApp6 thm a b a' b' h₁ h₂ + return (r, h) + + processSub (a b : Expr) : ToIntM (Expr × Expr) := do + let some thm := (← getInfo).subThm? | return (← toIntDef e) + let (a', h₁) ← toIntAndExpandWrap a + let (b', h₂) ← toIntAndExpandWrap b + let r ← mkWrap (mkIntSub a' b') + let h := mkApp6 thm a b a' b' h₁ h₂ + return (r, h) + + processNeg (a : Expr) : ToIntM (Expr × Expr) := do + let some thm := (← getInfo).negThm? | return (← toIntDef e) + let (a', h₁) ← toIntAndExpandWrap a + let r ← mkWrap (mkIntNeg a') + let h := mkApp3 thm a a' h₁ + return (r, h) + + processPow (a b : Expr) : ToIntM (Expr × Expr) := do + let some thm := (← getInfo).powThm? | return (← toIntDef e) + let (a', h₁) ← toIntAndExpandWrap a + let r ← mkWrap (mkIntPowNat a' b) + let h := mkApp4 thm a b a' h₁ + return (r, h) def toInt? (a : Expr) (type : Expr) : GoalM (Option (Expr × Expr)) := do ToIntM.run? type do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean index f3d6b83011..5b4b002807 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean @@ -10,10 +10,38 @@ import Lean.Meta.Tactic.Grind.Arith.Util namespace Lean.Meta.Grind.Arith.Cutsat open Lean Grind -structure ToIntCongr where +/-- +Theorems for operators that have support for `i.wrap` over `i.wrap` simplification. +Currently only addition and multiplication have `wrap` cancellation theorems +-/ +structure ToIntThms where + /-- + Basic theorem of the form + ``` + toInt a = a' → toInt b = b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b')` + ``` + -/ c? : Option Expr := none + /-- + Left-right `wrap` cancellation theorem of the form + ``` + toInt a = i.wrap a' → toInt b = i.wrap b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b') + ``` + -/ c_ww? : Option Expr := none + /-- + Left `wrap` cancellation theorem of the form + ``` + toInt a = i.wrap a' → toInt b = b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b') + ``` + -/ c_wl? : Option Expr := none + /-- + Right `wrap` cancellation theorem of the form + ``` + toInt a = a' → toInt b = i.wrap b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b') + ``` + -/ c_wr? : Option Expr := none structure ToIntInfo where @@ -32,8 +60,14 @@ structure ToIntInfo where ofNotLE? : Option Expr ofLT? : Option Expr ofNotLT? : Option Expr - add : ToIntCongr - mul : ToIntCongr - -- TODO: other operators + addThms : ToIntThms + mulThms : ToIntThms + subThm? : Option Expr + negThm? : Option Expr + divThm? : Option Expr + modThm? : Option Expr + powThm? : Option Expr + zeroThm? : Option Expr + ofNatThm? : Option Expr end Lean.Meta.Grind.Arith.Cutsat