feat: generic toInt for cutsat (#9022)

This PR completes the generic `toInt` infrastructure for embedding terms
implementing the `ToInt` type classes into `Int`.
This commit is contained in:
Leonardo de Moura 2025-06-26 17:28:51 -07:00 committed by GitHub
parent 2fe6d8a70b
commit 6b520ede08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 218 additions and 30 deletions

View file

@ -89,8 +89,45 @@ theorem mul_congr.wl {α i} [ToInt α i] [_root_.Mul α] [ToInt.Mul α i] (h : i
have := i.wrap_eq_self_iff h' _ |>.mpr (ToInt.toInt_mem b)
rw [h₂] at this; rw [← this] at h₂; apply mul_congr.ww h h₁ h₂
-- TODO: add theorems for other operations
/-! Subtraction -/
theorem sub_congr {α i} [ToInt α i] [_root_.Sub α] [ToInt.Sub α i] {a b : α} {a' b' : Int}
(h₁ : toInt a = a') (h₂ : toInt b = b') : toInt (a - b) = i.wrap (a' - b') := by
rw [ToInt.Sub.toInt_sub, h₁, h₂]
/-! Negation -/
theorem neg_congr {α i} [ToInt α i] [_root_.Neg α] [ToInt.Neg α i] {a : α} {a' : Int}
(h₁ : toInt a = a') : toInt (- a) = i.wrap (- a') := by
rw [ToInt.Neg.toInt_neg, h₁]
/-! Power -/
theorem pow_congr {α i} [ToInt α i] [HPow α Nat α] [ToInt.Pow α i] {a : α} (k : Nat) (a' : Int)
(h₁ : toInt a = a') : toInt (a ^ k) = i.wrap (a' ^ k) := by
rw [ToInt.Pow.toInt_pow, h₁]
/-! Division -/
theorem div_congr {α i} [ToInt α i] [_root_.Div α] [ToInt.Div α i] {a b : α} {a' b' : Int}
(h₁ : toInt a = a') (h₂ : toInt b = b') : toInt (a / b) = a' / b' := by
rw [ToInt.Div.toInt_div, h₁, h₂]
/-! Modulo -/
theorem mod_congr {α i} [ToInt α i] [_root_.Mod α] [ToInt.Mod α i] {a b : α} {a' b' : Int}
(h₁ : toInt a = a') (h₂ : toInt b = b') : toInt (a % b) = a' % b' := by
rw [ToInt.Mod.toInt_mod, h₁, h₂]
/-! OfNat -/
theorem ofNat_eq {α i} [ToInt α i] [∀ n, _root_.OfNat α n] [ToInt.OfNat α i] (n : Nat)
: toInt (OfNat.ofNat (α := α) n) = i.wrap n := by
apply ToInt.OfNat.toInt_ofNat
/-! Zero -/
theorem zero_eq {α i} [ToInt α i] [_root_.Zero α] [ToInt.Zero α i] : toInt (0 : α) = 0 := by
apply ToInt.Zero.toInt_zero
end Lean.Grind.ToInt

View file

@ -382,6 +382,13 @@ private def internalizeNat (e : Expr) : GoalM Unit := do
let c := { p := .add (-1) x p, h := .defnNat e' x e'' : EqCnstr }
c.assert
private def isToIntForbiddenParent (parent? : Option Expr) : Bool :=
if let some parent := parent? then
getKindAndType? parent |>.isSome
else
false
/--
Internalizes an integer (and `Nat`) expression. Here are the different cases that are handled.
@ -394,14 +401,16 @@ Internalizes an integer (and `Nat`) expression. Here are the different cases tha
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
unless (← getConfig).cutsat do return ()
let some (k, type) := getKindAndType? e | return ()
if isForbiddenParent parent? k then return ()
trace[grind.debug.cutsat.internalize] "{e} : {type}"
if type.isConstOf ``Int then
if isForbiddenParent parent? k then return ()
trace[grind.debug.cutsat.internalize] "{e} : {type}"
match k with
| .div => propagateDiv e
| .mod => propagateMod e
| _ => internalizeInt e
else if type.isConstOf ``Nat then
if isForbiddenParent parent? k then return ()
trace[grind.debug.cutsat.internalize] "{e} : {type}"
if (← hasForeignVar e) then return ()
discard <| mkForeignVar e .nat
match k with
@ -410,6 +419,8 @@ def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
| .toNat => propagateToNat e
| _ => internalizeNat e
else if let some (e', h) ← toInt? e type then
if isToIntForbiddenParent parent? then return ()
trace[grind.debug.cutsat.internalize] "{e} : {type}"
-- TODO: save `(e', h)`
trace[grind.debug.cutsat.toInt] "{e} ==> {e'}"
trace[grind.debug.cutsat.toInt] "{h} : {← inferType h}"

View file

@ -19,16 +19,41 @@ private def checkDecl (declName : Name) : MetaM Unit := do
unless (← getEnv).contains declName do
throwMissingDecl declName
private def mkOpCongr (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (range : Grind.IntInterval) (opBaseName : Name) (thmName : Name) : MetaM ToIntCongr := do
let op := mkApp (mkConst opBaseName [u]) type
let .some opInst ← trySynthInstance op | return {}
let toIntOpName := ``Grind.ToInt ++ opBaseName
private def mkOfNatThm? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) : MetaM (Option Expr) := do
-- ∀ n, OfNat α n
let ofNat := mkForall `n .default (mkConst ``Nat) (mkApp2 (mkConst ``OfNat [u]) type (mkBVar 0))
let .some ofNatInst ← trySynthInstance ofNat
| reportMissingToIntAdapter type ofNat; return none
let toIntOfNat := mkApp4 (mkConst ``Grind.ToInt.OfNat [u]) type ofNatInst rangeExpr toIntInst
let .some toIntOfNatInst ← trySynthInstance toIntOfNat
| reportMissingToIntAdapter type toIntOfNat; return none
return mkApp5 (mkConst ``Grind.ToInt.ofNat_eq [u]) type rangeExpr toIntInst ofNatInst toIntOfNatInst
/-- Helper function for `mkSimpleOpThm?` and `mkPowThm?` -/
private def mkSimpleOpThmCore? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (op : Expr) (opSuffix : Name) (thmName : Name) : MetaM (Option Expr) := do
let .some opInst ← trySynthInstance op | return none
let toIntOpName := ``Grind.ToInt ++ opSuffix
checkDecl toIntOpName
let toIntOp := mkApp4 (mkConst toIntOpName [u]) type opInst rangeExpr toIntInst
let .some toIntOpInst ← trySynthInstance toIntOp
| reportMissingToIntAdapter type toIntOp; return {}
| reportMissingToIntAdapter type toIntOp; return none
checkDecl thmName
let c := mkApp5 (mkConst thmName [u]) type rangeExpr toIntInst opInst toIntOpInst
return mkApp5 (mkConst thmName [u]) type rangeExpr toIntInst opInst toIntOpInst
/-- Simpler version of `mkBinOpThms` for operators that have only one congruence theorem. -/
private def mkSimpleOpThm? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (opBaseName : Name) (thmName : Name) : MetaM (Option Expr) := do
let op := mkApp (mkConst opBaseName [u]) type
mkSimpleOpThmCore? type u toIntInst rangeExpr op opBaseName thmName
/-- Simpler version of `mkBinOpThms` for operators that have only one congruence theorem. -/
private def mkPowThm? (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) : MetaM (Option Expr) := do
let op := mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type
mkSimpleOpThmCore? type u toIntInst rangeExpr op `Pow ``Grind.ToInt.pow_congr
private def mkBinOpThms (type : Expr) (u : Level) (toIntInst : Expr) (rangeExpr : Expr) (range : Grind.IntInterval) (opBaseName : Name) (thmName : Name) : MetaM ToIntThms := do
let some c ← mkSimpleOpThm? type u toIntInst rangeExpr opBaseName thmName | return {}
let opInst := c.appFn!.appArg!
let toIntOpInst := c.appArg!
let env ← getEnv
let cwwName := thmName ++ `ww
let cwlName := thmName ++ `wl
@ -118,13 +143,22 @@ where
let ofLT := mkApp5 (mkConst ``Grind.ToInt.of_lt [u]) type rangeExpr toIntInst ltInst toIntLTInst
let ofNotLT := mkApp5 (mkConst ``Grind.ToInt.of_not_lt [u]) type rangeExpr toIntInst ltInst toIntLTInst
pure (some ofLT, some ofNotLT)
let mkOp (opBaseName : Name) (thmName : Name) :=
mkOpCongr type u toIntInst rangeExpr range opBaseName thmName
let add ← mkOp ``Add ``Grind.ToInt.add_congr
let mul ← mkOp ``Mul ``Grind.ToInt.mul_congr
-- TODO: other operators
let mkBinOpThms (opBaseName : Name) (thmName : Name) :=
mkBinOpThms type u toIntInst rangeExpr range opBaseName thmName
let mkSimpleOpThm? (opBaseName : Name) (thmName : Name) :=
mkSimpleOpThm? type u toIntInst rangeExpr opBaseName thmName
let addThms ← mkBinOpThms ``Add ``Grind.ToInt.add_congr
let mulThms ← mkBinOpThms ``Mul ``Grind.ToInt.mul_congr
let subThm? ← mkSimpleOpThm? ``Sub ``Grind.ToInt.sub_congr
let negThm? ← mkSimpleOpThm? ``Neg ``Grind.ToInt.neg_congr
let divThm? ← mkSimpleOpThm? ``Div ``Grind.ToInt.div_congr
let modThm? ← mkSimpleOpThm? ``Mod ``Grind.ToInt.mod_congr
let powThm? ← mkPowThm? type u toIntInst rangeExpr
let zeroThm? ← mkSimpleOpThm? ``Zero ``Grind.ToInt.zero_eq
let ofNatThm? ← mkOfNatThm? type u toIntInst rangeExpr
return some {
type, u, toIntInst, rangeExpr, range, toInt, wrap, ofWrap0?, ofEq, ofDiseq, ofLE?, ofNotLE?, ofLT?, ofNotLT?, add, mul
type, u, toIntInst, rangeExpr, range, toInt, wrap, ofWrap0?, ofEq, ofDiseq, ofLE?, ofNotLE?, ofLT?, ofNotLT?, addThms, mulThms,
subThm?, negThm?, divThm?, modThm?, powThm?, zeroThm?, ofNatThm?
}
structure ToIntM.Context where
@ -171,34 +205,106 @@ private def expandWrap (a b : Expr) (h : Expr) : ToIntM (Expr × Expr) := do
return (b', h)
| _ => throwError "`grind cutsat`, `ToInt` interval not supported yet"
private def mkToIntResult (toIntOp : ToIntCongr) (mkBinOp : Expr → Expr → Expr) (a b : Expr) (a' b' : Expr) (h₁ h₂ : Expr) : ToIntM (Expr × Expr) := do
let f := toIntOp.c?.get!
/--
Given `h : toInt a = b`, if `b` is of the form `i.wrap b'`,
invokes `expandWrap a b' h`
-/
private def expandIfWrap (a b : Expr) (h : Expr) : ToIntM (Expr × Expr) := do
match isWrap b with
| none => return (b, h)
| some b => expandWrap a b h
private def mkWrap (a : Expr) : ToIntM Expr := do
return mkApp (← getInfo).wrap a
private def ToIntThms.mkResult (toIntThms : ToIntThms) (mkBinOp : Expr → Expr → Expr) (a b : Expr) (a' b' : Expr) (h₁ h₂ : Expr) : ToIntM (Expr × Expr) := do
let f := toIntThms.c?.get!
let mk (f : Expr) (a' b' : Expr) : ToIntM (Expr × Expr) := do
-- If the appropriate `wrap` cancellation theorem is missing, we have to expand the nested wrap.
let (a', h₁) ← expandIfWrap a a' h₁
let (b', h₂) ← expandIfWrap b b' h₂
let h := mkApp6 f a b a' b' h₁ h₂
let r := mkApp (← getInfo).wrap (mkBinOp a' b')
let r ← mkWrap (mkBinOp a' b')
return (r, h)
match isWrap a', isWrap b' with
| none, none => mk f a' b'
| some a'', none => if let some f := toIntOp.c_wl? then mk f a'' b' else mk f a' b'
| none, some b'' => if let some f := toIntOp.c_wr? then mk f a' b'' else mk f a' b'
| some a'', some b'' => if let some f := toIntOp.c_ww? then mk f a'' b'' else mk f a' b'
| some a'', none => if let some f := toIntThms.c_wl? then mk f a'' b' else mk f a' b'
| none, some b'' => if let some f := toIntThms.c_wr? then mk f a' b'' else mk f a' b'
| some a'', some b'' => if let some f := toIntThms.c_ww? then mk f a'' b'' else mk f a' b'
partial def toInt (e : Expr) : ToIntM (Expr × Expr) := do
match_expr e with
| HAdd.hAdd α β γ _ a b =>
unless isHomo α β γ do return (← toIntDef e)
toIntBin (← getInfo).add mkIntAdd a b
toIntBin (← getInfo).addThms mkIntAdd a b
| HMul.hMul α β γ _ a b =>
unless isHomo α β γ do return (← toIntDef e)
toIntBin (← getInfo).mul mkIntMul a b
-- TODO: other operators
toIntBin (← getInfo).mulThms mkIntMul a b
| HDiv.hDiv α β γ _ a b =>
unless isHomo α β γ do return (← toIntDef e)
processDivMod (isDiv := true) a b
| HMod.hMod α β γ _ a b =>
unless isHomo α β γ do return (← toIntDef e)
processDivMod (isDiv := false) a b
| HSub.hSub α β γ _ a b =>
unless isHomo α β γ do return (← toIntDef e)
processSub a b
| Neg.neg _ _ a =>
processNeg a
| HPow.hPow α β γ _ a b =>
unless isSameExpr α γ && β.isConstOf ``Nat do return (← toIntDef e)
processPow a b
| Zero.zero _ _ =>
let some thm := (← getInfo).zeroThm? | toIntDef e
return (mkIntLit 0, thm)
| OfNat.ofNat _ n _ =>
let some thm := (← getInfo).ofNatThm? | toIntDef e
let some n ← getNatValue? n | toIntDef e
let r := mkIntLit ((← getInfo).range.wrap n)
let h := mkApp thm (toExpr n)
return (r, h)
| _ => toIntDef e
where
toIntBin (toIntOp : ToIntCongr) (mkBinOp : Expr → Expr → Expr) (a b : Expr) : ToIntM (Expr × Expr) := do
toIntBin (toIntOp : ToIntThms) (mkBinOp : Expr → Expr → Expr) (a b : Expr) : ToIntM (Expr × Expr) := do
unless toIntOp.c?.isSome do return (← toIntDef e)
let (a', h₁) ← toInt a
let (b', h₂) ← toInt b
mkToIntResult toIntOp mkBinOp a b a' b' h₁ h₂
toIntOp.mkResult mkBinOp a b a' b' h₁ h₂
toIntAndExpandWrap (a : Expr) : ToIntM (Expr × Expr) := do
let (a', h₁) ← toInt a
expandIfWrap a a' h₁
processDivMod (isDiv : Bool) (a b : Expr) : ToIntM (Expr × Expr) := do
let some thm ← if isDiv then pure (← getInfo).divThm? else pure (← getInfo).modThm?
| return (← toIntDef e)
let (a', h₁) ← toIntAndExpandWrap a
let (b', h₂) ← toIntAndExpandWrap b
let r := if isDiv then mkIntDiv a' b' else mkIntMod a' b'
let h := mkApp6 thm a b a' b' h₁ h₂
return (r, h)
processSub (a b : Expr) : ToIntM (Expr × Expr) := do
let some thm := (← getInfo).subThm? | return (← toIntDef e)
let (a', h₁) ← toIntAndExpandWrap a
let (b', h₂) ← toIntAndExpandWrap b
let r ← mkWrap (mkIntSub a' b')
let h := mkApp6 thm a b a' b' h₁ h₂
return (r, h)
processNeg (a : Expr) : ToIntM (Expr × Expr) := do
let some thm := (← getInfo).negThm? | return (← toIntDef e)
let (a', h₁) ← toIntAndExpandWrap a
let r ← mkWrap (mkIntNeg a')
let h := mkApp3 thm a a' h₁
return (r, h)
processPow (a b : Expr) : ToIntM (Expr × Expr) := do
let some thm := (← getInfo).powThm? | return (← toIntDef e)
let (a', h₁) ← toIntAndExpandWrap a
let r ← mkWrap (mkIntPowNat a' b)
let h := mkApp4 thm a b a' h₁
return (r, h)
def toInt? (a : Expr) (type : Expr) : GoalM (Option (Expr × Expr)) := do
ToIntM.run? type do

View file

@ -10,10 +10,38 @@ import Lean.Meta.Tactic.Grind.Arith.Util
namespace Lean.Meta.Grind.Arith.Cutsat
open Lean Grind
structure ToIntCongr where
/--
Theorems for operators that have support for `i.wrap` over `i.wrap` simplification.
Currently only addition and multiplication have `wrap` cancellation theorems
-/
structure ToIntThms where
/--
Basic theorem of the form
```
toInt a = a' → toInt b = b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b')`
```
-/
c? : Option Expr := none
/--
Left-right `wrap` cancellation theorem of the form
```
toInt a = i.wrap a' → toInt b = i.wrap b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b')
```
-/
c_ww? : Option Expr := none
/--
Left `wrap` cancellation theorem of the form
```
toInt a = i.wrap a' → toInt b = b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b')
```
-/
c_wl? : Option Expr := none
/--
Right `wrap` cancellation theorem of the form
```
toInt a = a' → toInt b = i.wrap b' → toInt (a ⊞ b) = i.wrap (a' ⊞ b')
```
-/
c_wr? : Option Expr := none
structure ToIntInfo where
@ -32,8 +60,14 @@ structure ToIntInfo where
ofNotLE? : Option Expr
ofLT? : Option Expr
ofNotLT? : Option Expr
add : ToIntCongr
mul : ToIntCongr
-- TODO: other operators
addThms : ToIntThms
mulThms : ToIntThms
subThm? : Option Expr
negThm? : Option Expr
divThm? : Option Expr
modThm? : Option Expr
powThm? : Option Expr
zeroThm? : Option Expr
ofNatThm? : Option Expr
end Lean.Meta.Grind.Arith.Cutsat