feat: improve grind cutsat support for Fin n when n is not a numeral (#10022)

This PR improves support for `Fin n` in `grind cutsat` when `n` is not a
numeral. For example, the following goals can now be solved
automatically:

```lean
example (p d : Nat) (n : Fin (p + 1)) 
    : 2 ≤ p → p ≤ d + 1 → d = 1 → n = 0 ∨ n = 1 ∨ n = 2 := by
  grind

example (s : Nat) (i j : Fin (s + 1)) (hn : i ≠ j) (hl : ¬i < j) : j < i := by
  grind

example {n : Nat} (j : Fin (n + 1)) : j ≤ j := by
  grind

example {n : Nat} (x y : Fin ((n + 1) + 1)) (h₂ : ¬x = y) (h : ¬x < y) : y < x := by
  grind
```
This commit is contained in:
Leonardo de Moura 2025-08-21 10:25:52 -07:00 committed by GitHub
parent d9a73dd1e3
commit 0db795a1dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 197 additions and 52 deletions

View file

@ -51,7 +51,7 @@ abbrev uint (n : Nat) := IntInterval.co 0 (2 ^ n)
abbrev sint (n : Nat) := IntInterval.co (-(2 ^ (n - 1))) (2 ^ (n - 1))
/-- The lower bound of the interval, if finite. -/
def lo? (i : IntInterval) : Option Int :=
@[expose] def lo? (i : IntInterval) : Option Int :=
match i with
| co lo _ => some lo
| ci lo => some lo
@ -59,23 +59,21 @@ def lo? (i : IntInterval) : Option Int :=
| ii => none
/-- The upper bound of the interval, if finite. -/
def hi? (i : IntInterval) : Option Int :=
@[expose] def hi? (i : IntInterval) : Option Int :=
match i with
| co _ hi => some hi
| ci _ => none
| io hi => some hi
| ii => none
@[simp]
def nonEmpty (i : IntInterval) : Bool :=
@[simp, expose] def nonEmpty (i : IntInterval) : Bool :=
match i with
| co lo hi => lo < hi
| ci _ => true
| io _ => true
| ii => true
@[simp]
def isFinite (i : IntInterval) : Bool :=
@[simp, expose] def isFinite (i : IntInterval) : Bool :=
match i with
| co _ _ => true
| ci _

View file

@ -14,7 +14,7 @@ namespace Lean.Grind.ToInt
/-! Wrap -/
theorem of_eq_wrap_co_0 (i : IntInterval) (hi : Int) (h : i == .co 0 hi) {a b : Int} : a = i.wrap b → a = b % hi := by
theorem of_eq_wrap_co_0 (i : IntInterval) (hi : Int) (h : i = .co 0 hi) {a b : Int} : a = i.wrap b → a = b % hi := by
revert h
cases i <;> simp
intro h₁ h₂; subst h₁ h₂; simp
@ -154,4 +154,14 @@ theorem le_upper {α i} [ToInt α i] (hi' : Int) (h : i.hi? == some (-hi' + 1))
have h' := ToInt.toInt_mem a
revert h h'; cases i <;> simp [IntInterval.hi?] <;> intro h <;> subst h <;> intros <;> omega
theorem ge_lower' {α i} [ToInt α i] (lo : Int) (h : i.lo? = some lo) (a : α) : lo ≤ toInt a := by
have h' := ToInt.toInt_mem a
revert h h'; cases i <;> simp [IntInterval.lo?] <;> intro h <;> subst h <;> intros <;> assumption
theorem le_upper' {α i} [ToInt α i] (hi : Int) (h : i.hi? = some hi) (a : α) : toInt a + 1 ≤ hi := by
have h' := ToInt.toInt_mem a
revert h h'; cases i <;> simp [IntInterval.hi?] <;> intro h <;> subst h <;> intros
next h => exact Int.add_one_le_of_lt h
next h => exact Int.add_one_le_of_lt h
end Lean.Grind.ToInt

View file

@ -286,6 +286,10 @@ structure Config where
bitvector literals.
-/
bitVecOfNat : Bool := true
/--
When `true` (default: `true`), the `^` simprocs generate an warning it the exponents are too big.
-/
warnExponents : Bool := true
deriving Inhabited, BEq
-- Configuration object for `simp_all`

View file

@ -11,7 +11,7 @@ public import Lean.Meta.Tactic.Grind.SynthInstance
public import Lean.Meta.Tactic.Grind.Simp
public import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
public import Lean.Meta.Tactic.Grind.Arith.EvalNum
public import Lean.Meta.Tactic.Grind.Arith.Cutsat.Norm
public section
namespace Lean.Meta.Grind.Arith.Cutsat
@ -26,6 +26,25 @@ private def checkDecl (declName : Name) : MetaM Unit := do
unless (← getEnv).contains declName do
throwMissingDecl declName
private def normalizeBound (bound : Expr) : GrindM SymbolicBound := do
if let some bound ← evalInt? bound then
return { val := mkIntLit bound, ival? := some bound }
else
return { val := bound, ival? := none }
def SymbolicBound.isZero (b : SymbolicBound) : Bool :=
if let some b := b.ival? then
b == 0
else
false
/-- Given a symbolic bound `b`, returns `-b + 1` -/
def SymbolicBound.mkIntNegSucc (b : SymbolicBound) : MetaM Expr := do
if let some val := b.ival? then
return mkIntLit (-val + 1)
else
return mkIntAdd (mkIntNeg b.val) (mkIntLit 1)
def getToIntId? (type : Expr) : GoalM (Option Nat) := do
if let some id? := (← get').toIntIds.find? { expr := type } then
return id?
@ -35,16 +54,12 @@ def getToIntId? (type : Expr) : GoalM (Option Nat) := do
toIntIds := s.toIntIds.insert { expr := type } id? }
return id?
where
toIntInterval? (rangeExpr : Expr) : GoalM (Option Grind.IntInterval) := do
toIntInterval? (rangeExpr : Expr) : GoalM (Option SymbolicIntInterval) := do
let rangeExpr ← whnfD rangeExpr
match_expr rangeExpr with
| Grind.IntInterval.co lo hi =>
let some lo ← evalInt? lo
| trace[grind.debug.cutsat.toInt] "`ToInt` lower bound could not be reduced to an integer{indentExpr (← whnfD lo)}\nfor type{indentExpr type}"
return none
let some hi ← evalInt? hi
| trace[grind.debug.cutsat.toInt] "`ToInt` upper bound could not be reduced to an integer{indentExpr hi}\nfor type{indentExpr type}"
return none
let lo ← normalizeBound lo
let hi ← normalizeBound hi
return some (.co lo hi)
| _ =>
trace[grind.debug.cutsat.toInt] "unsupported `ToInt` interval{indentExpr rangeExpr}\nfor type{indentExpr type}"
@ -62,21 +77,32 @@ where
let some range ← toIntInterval? rangeExpr | return none
let toInt := mkApp3 (mkConst ``Grind.ToInt.toInt [u]) type rangeExpr toIntInst
let wrap := mkApp (mkConst ``Grind.IntInterval.wrap) rangeExpr
let ofWrap0? := if let .co 0 hi := range then
some <| mkApp3 (mkConst ``Grind.ToInt.of_eq_wrap_co_0) rangeExpr (toExpr hi) eagerReflBoolTrue
else
none
let ofWrap0? ← if let .co lo hi := range then
if lo.isZero then
pure <| some <| mkApp3 (mkConst ``Grind.ToInt.of_eq_wrap_co_0) rangeExpr hi.val (← mkEqRefl rangeExpr)
else pure none
else pure none
let ofEq := mkApp3 (mkConst ``Grind.ToInt.of_eq [u]) type rangeExpr toIntInst
let ofDiseq := mkApp3 (mkConst ``Grind.ToInt.of_diseq [u]) type rangeExpr toIntInst
let lowerThm? := if let some lo := range.lo? then
if lo == 0 then
some <| mkApp4 (mkConst ``Grind.ToInt.ge_lower0 [u]) type rangeExpr toIntInst eagerReflBoolTrue
let lowerThm? ← if let some lo := range.lo? then
if let some lo' := lo.ival? then
if lo' == 0 then
pure <| some <| mkApp4 (mkConst ``Grind.ToInt.ge_lower0 [u]) type rangeExpr toIntInst eagerReflBoolTrue
else
pure <| some <| mkApp5 (mkConst ``Grind.ToInt.ge_lower [u]) type rangeExpr toIntInst lo.val eagerReflBoolTrue
else
some <| mkApp5 (mkConst ``Grind.ToInt.ge_lower [u]) type rangeExpr toIntInst (toExpr lo) eagerReflBoolTrue
else none
let upperThm? := if let some hi := range.hi? then
some <| mkApp5 (mkConst ``Grind.ToInt.le_upper [u]) type rangeExpr toIntInst (toExpr (-hi + 1)) eagerReflBoolTrue
else none
-- Symbolic case
let some_lo ← mkSome Int.mkType lo.val
pure <| some <| mkApp5 (mkConst ``Grind.ToInt.ge_lower' [u]) type rangeExpr toIntInst lo.val (← mkEqRefl some_lo)
else pure none
let upperThm? ← if let some hi := range.hi? then
if hi.isNumeral then
pure <| some <| mkApp5 (mkConst ``Grind.ToInt.le_upper [u]) type rangeExpr toIntInst (← hi.mkIntNegSucc) eagerReflBoolTrue
else
-- Symbolic case
let some_hi ← mkSome Int.mkType hi.val
pure <| some <| mkApp5 (mkConst ``Grind.ToInt.le_upper' [u]) type rangeExpr toIntInst hi.val (← mkEqRefl some_hi)
else pure none
trace[grind.debug.cutsat.toInt] "registered toInt: {type}"
let id := (← get').toIntInfos.size
modify' fun s => { s with toIntInfos := s.toIntInfos.push { id, type, u, toIntInst, rangeExpr, range, toInt, wrap, ofWrap0?, ofEq, ofDiseq, lowerThm?, upperThm? } }
@ -292,21 +318,46 @@ private def isWrap (e : Expr) : Option Expr :=
| Grind.IntInterval.wrap _ a => some a
| _ => none
private def hasNumericLoHi : ToIntM Bool := do
let info ← getInfo
let some lo := info.range.lo? | return false
let some hi := info.range.hi? | return false
return lo.isNumeral && hi.isNumeral
/--
Given `h : toInt a = i.wrap b`, return `(b', h)` where
`b'` is the expanded form of `i.wrap b`, and `h : toInt a = b'`
-/
private def expandWrap (a b : Expr) (h : Expr) : ToIntM (Expr × Expr) := do
match (← getInfo).range with
let range := (← getInfo).range
match range with
| .ii => return (b, h)
| .co 0 hi =>
let b' := mkIntMod b (toExpr hi)
let toA := mkApp (← getInfo).toInt a
let h := mkApp3 (← getInfo).ofWrap0?.get! toA b h
return (b', h)
| .co lo hi =>
let b' := mkIntAdd (mkIntMod (mkIntSub b (toExpr lo)) (toExpr (hi - lo))) (toExpr lo)
return (b', h)
if lo.isZero then
let b' := mkIntMod b hi.val
let toA := mkApp (← getInfo).toInt a
let h := mkApp3 (← getInfo).ofWrap0?.get! toA b h
if hi.isNumeral then
return (b', h)
else
-- We must preprocess `b'` because `hi` has not been normalized and may interact with `%`
let r ← preprocess b'
let h ← mkEqTrans h (← r.getProof)
let b' := r.expr
internalize b' (← getGeneration b)
return (b', h)
else
let b' ← range.wrap b
if (← hasNumericLoHi) then
return (b', h)
else
-- We must preprocess `b'` because `lo` and/or `hi` are symbolic values that may
-- interact with the wrap operations and have not been normalized yet.
let r ← preprocess b'
let h ← mkEqTrans h (← r.getProof)
let b' := r.expr
internalize b' (← getGeneration b)
return (b', h)
| _ => throwError "`grind cutsat`, `ToInt` interval not supported yet"
/--
@ -364,9 +415,12 @@ private partial def toInt' (e : Expr) : ToIntM (Expr × Expr) := do
| OfNat.ofNat _ n _ =>
let some thm ← getOfNatThm? | mkToIntVar e
let some n ← getNatValue? n | mkToIntVar e
let r := mkIntLit ((← getInfo).range.wrap n)
let h := mkApp thm (toExpr n)
return (r, h)
if (← hasNumericLoHi) then
let r ← (← getInfo).range.wrap (mkIntLit n)
return (r, h)
else
expandWrap e (mkIntLit n) h
| _ => mkToIntVar e
where
toIntBin (toIntOp : ToIntThms) (mkBinOp : Expr → Expr → Expr) (a b : Expr) : ToIntM (Expr × Expr) := do
@ -441,13 +495,19 @@ def assertToIntBounds (e : Expr) (x : Var) : GoalM Unit := do
let i := info.range
if let some lo := i.lo? then
let some thm := info.lowerThm? | unreachable!
let p := .add (-1) x (.num lo)
let c := { p, h := .bound (mkApp thm a) : LeCnstr }
c.assert
if let some lo := lo.ival? then
let p := .add (-1) x (.num lo)
let c := { p, h := .bound (mkApp thm a) : LeCnstr }
c.assert
else
pushNewFact <| mkApp thm a
if let some hi := i.hi? then
let some thm := info.upperThm? | unreachable!
let p := .add 1 x (.num (-hi + 1))
let c := { p, h := .bound (mkApp thm a) : LeCnstr }
c.assert
if let some hi := hi.ival? then
let p := .add 1 x (.num (-hi + 1))
let c := { p, h := .bound (mkApp thm a) : LeCnstr }
c.assert
else
pushNewFact <| mkApp thm a
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -49,13 +49,63 @@ structure ToIntThms where
c_wr? : Option Expr := none
deriving Inhabited
structure SymbolicBound where
val : Expr
-- cached int value if `val` is a numeric
ival? : Option Int
deriving Inhabited
def SymbolicBound.isNumeral (b : SymbolicBound) : Bool :=
b.ival?.isSome
/-- Similar to `IntInterval`, but with symbolic bounds. -/
inductive SymbolicIntInterval : Type where
| co (lo hi : SymbolicBound)
| ci (lo : SymbolicBound)
| io (hi : SymbolicBound)
| ii
deriving Inhabited
def SymbolicIntInterval.isFinite (i : SymbolicIntInterval) : Bool :=
match i with
| .co _ _ => true
| .ci _ | .io _ | .ii => false
def SymbolicIntInterval.lo? (i : SymbolicIntInterval) : Option SymbolicBound :=
match i with
| .co lo _ | .ci lo => some lo
| .io _ | .ii => none
def SymbolicIntInterval.hi? (i : SymbolicIntInterval) : Option SymbolicBound :=
match i with
| .co _ hi | .io hi => some hi
| .ci _ | .ii => none
def SymbolicIntInterval.wrap (i : SymbolicIntInterval) (x : Expr) : MetaM Expr := do
match i with
| .co lo hi =>
if let some lo' := lo.ival? then
if let some hi' := hi.ival? then
if let some x ← getIntValue? x then
return mkIntLit ((x - lo') % (hi' - lo') + lo')
else if lo' == 0 then
return mkIntMod x hi.val
else
return mkIntAdd (mkIntMod (mkIntSub x (mkIntLit lo')) (mkIntLit (hi' - lo'))) (mkIntLit lo')
if lo' == 0 then
return mkIntMod x hi.val
return mkIntAdd (mkIntMod (mkIntSub x lo.val) (mkIntSub hi.val lo.val)) lo.val
| .ci _ => throwError "`grind` internal error, `.ci` interval support has not been implemented yet"
| .io _ => throwError "`grind` internal error, `.io` interval support has not been implemented yet"
| .ii => return x
structure ToIntInfo where
id : Nat
type : Expr
u : Level
toIntInst : Expr
rangeExpr : Expr
range : IntInterval
range : SymbolicIntInterval
toInt : Expr
wrap : Expr
-- theorem `of_eq_wrap_co_0` if `range == .co 0 hi`

View file

@ -187,11 +187,13 @@ protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do
thms ← addDeclToUnfold thms ``Ne
Simp.mkContext
(config :=
{ arith := true, zeta := config.zeta,
zetaDelta := config.zetaDelta,
{ arith := true
zeta := config.zeta
zetaDelta := config.zetaDelta
-- Use `OfNat.ofNat` and `Neg.neg` for representing bitvec literals
bitVecOfNat := false,
catchRuntime := false,
bitVecOfNat := false
catchRuntime := false
warnExponents := false
-- `implicitDefEqProofs := true` a recurrent source of performance problems in the kernel
implicitDefEqProofs := false })
(simpTheorems := #[thms])

View file

@ -82,7 +82,8 @@ builtin_dsimproc [simp, seval] reducePow ((_ : Int) ^ (_ : Nat)) := fun e => do
let_expr HPow.hPow _ _ _ _ a b ← e | return .continue
let some v₁ ← fromExpr? a | return .continue
let some v₂ ← Nat.fromExpr? b | return .continue
unless (← checkExponent v₂) do return .continue
let warning := (← Simp.getConfig).warnExponents
unless (← checkExponent v₂ (warning := warning)) do return .continue
return .done <| toExpr (v₁ ^ v₂)
builtin_simproc [simp, seval] reduceLT (( _ : Int) < _) := reduceBinPred ``LT.lt 4 (. < .)

View file

@ -62,7 +62,8 @@ builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := fun e => do
let_expr HPow.hPow _ _ _ _ n m := e | return .continue
let some n ← fromExpr? n | return .continue
let some m ← fromExpr? m | return .continue
unless (← checkExponent m) do return .continue
let warning := (← Simp.getConfig).warnExponents
unless (← checkExponent m (warning := warning)) do return .continue
return .done <| toExpr (n ^ m)
builtin_dsimproc [simp, seval] reduceAnd ((_ &&& _ : Nat)) := reduceBin ``HAnd.hAnd 6 (· &&& ·)

View file

@ -26,10 +26,10 @@ Returns `true` if `n` is `≤ exponentiation.threshold`. Otherwise,
reports a warning and returns `false`.
This method ensures there is at most one warning message of this kind in the message log.
-/
def checkExponent (n : Nat) : CoreM Bool := do
def checkExponent (n : Nat) (warning := true) : CoreM Bool := do
let threshold := exponentiation.threshold.get (← getOptions)
if n > threshold then
if (← logMessageKind `unsafe.exponentiation) then
if (← pure warning <&&> logMessageKind `unsafe.exponentiation) then
logWarning s!"exponent {n} exceeds the threshold {threshold}, exponentiation operation was not evaluated, use `set_option {exponentiation.threshold.name} <num>` to set a new threshold"
return false
else

View file

@ -94,3 +94,22 @@ example (a b : Fin 3) : a > 0 → a ≠ b → a + b ≠ 0 → a + b ≠ 1 → Fa
set_option trace.grind.debug.ring.basis true in
example (a b : Fin 3) : a > 0 → a ≠ b → a + b ≠ 0 → a + b ≠ 1 → False := by
grind
example (p : Nat) (heq : p = 0) (n : Fin (p + 1)) : n = 0 := by
grind
example (p : Nat) (heq : p = 1) (n : Fin (p + 1)) : n = 0 n = 1 := by
grind
example (p d : Nat) (n : Fin (p + 1)) : 2 ≤ p → p ≤ d + 1 → d = 1 → n = 0 n = 1 n = 2 := by
grind
example (s : Nat)
(i j : Fin (s + 1)) (hn : i ≠ j) (hl : ¬i < j) : j < i := by
grind
example {n : Nat} (j : Fin (n + 1)) : j ≤ j := by
grind
example {n : Nat} (x y : Fin ((n + 1) + 1)) (h₂ : ¬x = y) (h : ¬x < y) : y < x := by
grind