diff --git a/src/Init/Grind/ToInt.lean b/src/Init/Grind/ToInt.lean index 651dd7ba70..258d641e6a 100644 --- a/src/Init/Grind/ToInt.lean +++ b/src/Init/Grind/ToInt.lean @@ -51,7 +51,7 @@ abbrev uint (n : Nat) := IntInterval.co 0 (2 ^ n) abbrev sint (n : Nat) := IntInterval.co (-(2 ^ (n - 1))) (2 ^ (n - 1)) /-- The lower bound of the interval, if finite. -/ -def lo? (i : IntInterval) : Option Int := +@[expose] def lo? (i : IntInterval) : Option Int := match i with | co lo _ => some lo | ci lo => some lo @@ -59,23 +59,21 @@ def lo? (i : IntInterval) : Option Int := | ii => none /-- The upper bound of the interval, if finite. -/ -def hi? (i : IntInterval) : Option Int := +@[expose] def hi? (i : IntInterval) : Option Int := match i with | co _ hi => some hi | ci _ => none | io hi => some hi | ii => none -@[simp] -def nonEmpty (i : IntInterval) : Bool := +@[simp, expose] def nonEmpty (i : IntInterval) : Bool := match i with | co lo hi => lo < hi | ci _ => true | io _ => true | ii => true -@[simp] -def isFinite (i : IntInterval) : Bool := +@[simp, expose] def isFinite (i : IntInterval) : Bool := match i with | co _ _ => true | ci _ diff --git a/src/Init/Grind/ToIntLemmas.lean b/src/Init/Grind/ToIntLemmas.lean index 07db2a4e7a..e1fa458b46 100644 --- a/src/Init/Grind/ToIntLemmas.lean +++ b/src/Init/Grind/ToIntLemmas.lean @@ -14,7 +14,7 @@ namespace Lean.Grind.ToInt /-! Wrap -/ -theorem of_eq_wrap_co_0 (i : IntInterval) (hi : Int) (h : i == .co 0 hi) {a b : Int} : a = i.wrap b → a = b % hi := by +theorem of_eq_wrap_co_0 (i : IntInterval) (hi : Int) (h : i = .co 0 hi) {a b : Int} : a = i.wrap b → a = b % hi := by revert h cases i <;> simp intro h₁ h₂; subst h₁ h₂; simp @@ -154,4 +154,14 @@ theorem le_upper {α i} [ToInt α i] (hi' : Int) (h : i.hi? == some (-hi' + 1)) have h' := ToInt.toInt_mem a revert h h'; cases i <;> simp [IntInterval.hi?] <;> intro h <;> subst h <;> intros <;> omega +theorem ge_lower' {α i} [ToInt α i] (lo : Int) (h : i.lo? = some lo) (a : α) : lo ≤ toInt a := by + have h' := ToInt.toInt_mem a + revert h h'; cases i <;> simp [IntInterval.lo?] <;> intro h <;> subst h <;> intros <;> assumption + +theorem le_upper' {α i} [ToInt α i] (hi : Int) (h : i.hi? = some hi) (a : α) : toInt a + 1 ≤ hi := by + have h' := ToInt.toInt_mem a + revert h h'; cases i <;> simp [IntInterval.hi?] <;> intro h <;> subst h <;> intros + next h => exact Int.add_one_le_of_lt h + next h => exact Int.add_one_le_of_lt h + end Lean.Grind.ToInt diff --git a/src/Init/MetaTypes.lean b/src/Init/MetaTypes.lean index 7152608779..3d53eacc8c 100644 --- a/src/Init/MetaTypes.lean +++ b/src/Init/MetaTypes.lean @@ -286,6 +286,10 @@ structure Config where bitvector literals. -/ bitVecOfNat : Bool := true + /-- + When `true` (default: `true`), the `^` simprocs generate an warning it the exponents are too big. + -/ + warnExponents : Bool := true deriving Inhabited, BEq -- Configuration object for `simp_all` diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean index 7a895d23ed..f1c84ee5a5 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToInt.lean @@ -11,7 +11,7 @@ public import Lean.Meta.Tactic.Grind.SynthInstance public import Lean.Meta.Tactic.Grind.Simp public import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util public import Lean.Meta.Tactic.Grind.Arith.EvalNum - +public import Lean.Meta.Tactic.Grind.Arith.Cutsat.Norm public section namespace Lean.Meta.Grind.Arith.Cutsat @@ -26,6 +26,25 @@ private def checkDecl (declName : Name) : MetaM Unit := do unless (← getEnv).contains declName do throwMissingDecl declName +private def normalizeBound (bound : Expr) : GrindM SymbolicBound := do + if let some bound ← evalInt? bound then + return { val := mkIntLit bound, ival? := some bound } + else + return { val := bound, ival? := none } + +def SymbolicBound.isZero (b : SymbolicBound) : Bool := + if let some b := b.ival? then + b == 0 + else + false + +/-- Given a symbolic bound `b`, returns `-b + 1` -/ +def SymbolicBound.mkIntNegSucc (b : SymbolicBound) : MetaM Expr := do + if let some val := b.ival? then + return mkIntLit (-val + 1) + else + return mkIntAdd (mkIntNeg b.val) (mkIntLit 1) + def getToIntId? (type : Expr) : GoalM (Option Nat) := do if let some id? := (← get').toIntIds.find? { expr := type } then return id? @@ -35,16 +54,12 @@ def getToIntId? (type : Expr) : GoalM (Option Nat) := do toIntIds := s.toIntIds.insert { expr := type } id? } return id? where - toIntInterval? (rangeExpr : Expr) : GoalM (Option Grind.IntInterval) := do + toIntInterval? (rangeExpr : Expr) : GoalM (Option SymbolicIntInterval) := do let rangeExpr ← whnfD rangeExpr match_expr rangeExpr with | Grind.IntInterval.co lo hi => - let some lo ← evalInt? lo - | trace[grind.debug.cutsat.toInt] "`ToInt` lower bound could not be reduced to an integer{indentExpr (← whnfD lo)}\nfor type{indentExpr type}" - return none - let some hi ← evalInt? hi - | trace[grind.debug.cutsat.toInt] "`ToInt` upper bound could not be reduced to an integer{indentExpr hi}\nfor type{indentExpr type}" - return none + let lo ← normalizeBound lo + let hi ← normalizeBound hi return some (.co lo hi) | _ => trace[grind.debug.cutsat.toInt] "unsupported `ToInt` interval{indentExpr rangeExpr}\nfor type{indentExpr type}" @@ -62,21 +77,32 @@ where let some range ← toIntInterval? rangeExpr | return none let toInt := mkApp3 (mkConst ``Grind.ToInt.toInt [u]) type rangeExpr toIntInst let wrap := mkApp (mkConst ``Grind.IntInterval.wrap) rangeExpr - let ofWrap0? := if let .co 0 hi := range then - some <| mkApp3 (mkConst ``Grind.ToInt.of_eq_wrap_co_0) rangeExpr (toExpr hi) eagerReflBoolTrue - else - none + let ofWrap0? ← if let .co lo hi := range then + if lo.isZero then + pure <| some <| mkApp3 (mkConst ``Grind.ToInt.of_eq_wrap_co_0) rangeExpr hi.val (← mkEqRefl rangeExpr) + else pure none + else pure none let ofEq := mkApp3 (mkConst ``Grind.ToInt.of_eq [u]) type rangeExpr toIntInst let ofDiseq := mkApp3 (mkConst ``Grind.ToInt.of_diseq [u]) type rangeExpr toIntInst - let lowerThm? := if let some lo := range.lo? then - if lo == 0 then - some <| mkApp4 (mkConst ``Grind.ToInt.ge_lower0 [u]) type rangeExpr toIntInst eagerReflBoolTrue + let lowerThm? ← if let some lo := range.lo? then + if let some lo' := lo.ival? then + if lo' == 0 then + pure <| some <| mkApp4 (mkConst ``Grind.ToInt.ge_lower0 [u]) type rangeExpr toIntInst eagerReflBoolTrue + else + pure <| some <| mkApp5 (mkConst ``Grind.ToInt.ge_lower [u]) type rangeExpr toIntInst lo.val eagerReflBoolTrue else - some <| mkApp5 (mkConst ``Grind.ToInt.ge_lower [u]) type rangeExpr toIntInst (toExpr lo) eagerReflBoolTrue - else none - let upperThm? := if let some hi := range.hi? then - some <| mkApp5 (mkConst ``Grind.ToInt.le_upper [u]) type rangeExpr toIntInst (toExpr (-hi + 1)) eagerReflBoolTrue - else none + -- Symbolic case + let some_lo ← mkSome Int.mkType lo.val + pure <| some <| mkApp5 (mkConst ``Grind.ToInt.ge_lower' [u]) type rangeExpr toIntInst lo.val (← mkEqRefl some_lo) + else pure none + let upperThm? ← if let some hi := range.hi? then + if hi.isNumeral then + pure <| some <| mkApp5 (mkConst ``Grind.ToInt.le_upper [u]) type rangeExpr toIntInst (← hi.mkIntNegSucc) eagerReflBoolTrue + else + -- Symbolic case + let some_hi ← mkSome Int.mkType hi.val + pure <| some <| mkApp5 (mkConst ``Grind.ToInt.le_upper' [u]) type rangeExpr toIntInst hi.val (← mkEqRefl some_hi) + else pure none trace[grind.debug.cutsat.toInt] "registered toInt: {type}" let id := (← get').toIntInfos.size modify' fun s => { s with toIntInfos := s.toIntInfos.push { id, type, u, toIntInst, rangeExpr, range, toInt, wrap, ofWrap0?, ofEq, ofDiseq, lowerThm?, upperThm? } } @@ -292,21 +318,46 @@ private def isWrap (e : Expr) : Option Expr := | Grind.IntInterval.wrap _ a => some a | _ => none +private def hasNumericLoHi : ToIntM Bool := do + let info ← getInfo + let some lo := info.range.lo? | return false + let some hi := info.range.hi? | return false + return lo.isNumeral && hi.isNumeral + /-- Given `h : toInt a = i.wrap b`, return `(b', h)` where `b'` is the expanded form of `i.wrap b`, and `h : toInt a = b'` -/ private def expandWrap (a b : Expr) (h : Expr) : ToIntM (Expr × Expr) := do - match (← getInfo).range with + let range := (← getInfo).range + match range with | .ii => return (b, h) - | .co 0 hi => - let b' := mkIntMod b (toExpr hi) - let toA := mkApp (← getInfo).toInt a - let h := mkApp3 (← getInfo).ofWrap0?.get! toA b h - return (b', h) | .co lo hi => - let b' := mkIntAdd (mkIntMod (mkIntSub b (toExpr lo)) (toExpr (hi - lo))) (toExpr lo) - return (b', h) + if lo.isZero then + let b' := mkIntMod b hi.val + let toA := mkApp (← getInfo).toInt a + let h := mkApp3 (← getInfo).ofWrap0?.get! toA b h + if hi.isNumeral then + return (b', h) + else + -- We must preprocess `b'` because `hi` has not been normalized and may interact with `%` + let r ← preprocess b' + let h ← mkEqTrans h (← r.getProof) + let b' := r.expr + internalize b' (← getGeneration b) + return (b', h) + else + let b' ← range.wrap b + if (← hasNumericLoHi) then + return (b', h) + else + -- We must preprocess `b'` because `lo` and/or `hi` are symbolic values that may + -- interact with the wrap operations and have not been normalized yet. + let r ← preprocess b' + let h ← mkEqTrans h (← r.getProof) + let b' := r.expr + internalize b' (← getGeneration b) + return (b', h) | _ => throwError "`grind cutsat`, `ToInt` interval not supported yet" /-- @@ -364,9 +415,12 @@ private partial def toInt' (e : Expr) : ToIntM (Expr × Expr) := do | OfNat.ofNat _ n _ => let some thm ← getOfNatThm? | mkToIntVar e let some n ← getNatValue? n | mkToIntVar e - let r := mkIntLit ((← getInfo).range.wrap n) let h := mkApp thm (toExpr n) - return (r, h) + if (← hasNumericLoHi) then + let r ← (← getInfo).range.wrap (mkIntLit n) + return (r, h) + else + expandWrap e (mkIntLit n) h | _ => mkToIntVar e where toIntBin (toIntOp : ToIntThms) (mkBinOp : Expr → Expr → Expr) (a b : Expr) : ToIntM (Expr × Expr) := do @@ -441,13 +495,19 @@ def assertToIntBounds (e : Expr) (x : Var) : GoalM Unit := do let i := info.range if let some lo := i.lo? then let some thm := info.lowerThm? | unreachable! - let p := .add (-1) x (.num lo) - let c := { p, h := .bound (mkApp thm a) : LeCnstr } - c.assert + if let some lo := lo.ival? then + let p := .add (-1) x (.num lo) + let c := { p, h := .bound (mkApp thm a) : LeCnstr } + c.assert + else + pushNewFact <| mkApp thm a if let some hi := i.hi? then let some thm := info.upperThm? | unreachable! - let p := .add 1 x (.num (-hi + 1)) - let c := { p, h := .bound (mkApp thm a) : LeCnstr } - c.assert + if let some hi := hi.ival? then + let p := .add 1 x (.num (-hi + 1)) + let c := { p, h := .bound (mkApp thm a) : LeCnstr } + c.assert + else + pushNewFact <| mkApp thm a end Lean.Meta.Grind.Arith.Cutsat diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean index c5b722aedb..b56d928975 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ToIntInfo.lean @@ -49,13 +49,63 @@ structure ToIntThms where c_wr? : Option Expr := none deriving Inhabited +structure SymbolicBound where + val : Expr + -- cached int value if `val` is a numeric + ival? : Option Int + deriving Inhabited + +def SymbolicBound.isNumeral (b : SymbolicBound) : Bool := + b.ival?.isSome + +/-- Similar to `IntInterval`, but with symbolic bounds. -/ +inductive SymbolicIntInterval : Type where + | co (lo hi : SymbolicBound) + | ci (lo : SymbolicBound) + | io (hi : SymbolicBound) + | ii + deriving Inhabited + +def SymbolicIntInterval.isFinite (i : SymbolicIntInterval) : Bool := + match i with + | .co _ _ => true + | .ci _ | .io _ | .ii => false + +def SymbolicIntInterval.lo? (i : SymbolicIntInterval) : Option SymbolicBound := + match i with + | .co lo _ | .ci lo => some lo + | .io _ | .ii => none + +def SymbolicIntInterval.hi? (i : SymbolicIntInterval) : Option SymbolicBound := + match i with + | .co _ hi | .io hi => some hi + | .ci _ | .ii => none + +def SymbolicIntInterval.wrap (i : SymbolicIntInterval) (x : Expr) : MetaM Expr := do + match i with + | .co lo hi => + if let some lo' := lo.ival? then + if let some hi' := hi.ival? then + if let some x ← getIntValue? x then + return mkIntLit ((x - lo') % (hi' - lo') + lo') + else if lo' == 0 then + return mkIntMod x hi.val + else + return mkIntAdd (mkIntMod (mkIntSub x (mkIntLit lo')) (mkIntLit (hi' - lo'))) (mkIntLit lo') + if lo' == 0 then + return mkIntMod x hi.val + return mkIntAdd (mkIntMod (mkIntSub x lo.val) (mkIntSub hi.val lo.val)) lo.val + | .ci _ => throwError "`grind` internal error, `.ci` interval support has not been implemented yet" + | .io _ => throwError "`grind` internal error, `.io` interval support has not been implemented yet" + | .ii => return x + structure ToIntInfo where id : Nat type : Expr u : Level toIntInst : Expr rangeExpr : Expr - range : IntInterval + range : SymbolicIntInterval toInt : Expr wrap : Expr -- theorem `of_eq_wrap_co_0` if `range == .co 0 hi` diff --git a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean index e8ce932785..7bd7182bc0 100644 --- a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean @@ -187,11 +187,13 @@ protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do thms ← addDeclToUnfold thms ``Ne Simp.mkContext (config := - { arith := true, zeta := config.zeta, - zetaDelta := config.zetaDelta, + { arith := true + zeta := config.zeta + zetaDelta := config.zetaDelta -- Use `OfNat.ofNat` and `Neg.neg` for representing bitvec literals - bitVecOfNat := false, - catchRuntime := false, + bitVecOfNat := false + catchRuntime := false + warnExponents := false -- `implicitDefEqProofs := true` a recurrent source of performance problems in the kernel implicitDefEqProofs := false }) (simpTheorems := #[thms]) diff --git a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean index 76edc8082b..859bc08154 100644 --- a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean +++ b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean @@ -82,7 +82,8 @@ builtin_dsimproc [simp, seval] reducePow ((_ : Int) ^ (_ : Nat)) := fun e => do let_expr HPow.hPow _ _ _ _ a b ← e | return .continue let some v₁ ← fromExpr? a | return .continue let some v₂ ← Nat.fromExpr? b | return .continue - unless (← checkExponent v₂) do return .continue + let warning := (← Simp.getConfig).warnExponents + unless (← checkExponent v₂ (warning := warning)) do return .continue return .done <| toExpr (v₁ ^ v₂) builtin_simproc [simp, seval] reduceLT (( _ : Int) < _) := reduceBinPred ``LT.lt 4 (. < .) diff --git a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean index 8714db7da7..addbc3c552 100644 --- a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean +++ b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean @@ -62,7 +62,8 @@ builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := fun e => do let_expr HPow.hPow _ _ _ _ n m := e | return .continue let some n ← fromExpr? n | return .continue let some m ← fromExpr? m | return .continue - unless (← checkExponent m) do return .continue + let warning := (← Simp.getConfig).warnExponents + unless (← checkExponent m (warning := warning)) do return .continue return .done <| toExpr (n ^ m) builtin_dsimproc [simp, seval] reduceAnd ((_ &&& _ : Nat)) := reduceBin ``HAnd.hAnd 6 (· &&& ·) diff --git a/src/Lean/Util/SafeExponentiation.lean b/src/Lean/Util/SafeExponentiation.lean index 95a23f1deb..476983a590 100644 --- a/src/Lean/Util/SafeExponentiation.lean +++ b/src/Lean/Util/SafeExponentiation.lean @@ -26,10 +26,10 @@ Returns `true` if `n` is `≤ exponentiation.threshold`. Otherwise, reports a warning and returns `false`. This method ensures there is at most one warning message of this kind in the message log. -/ -def checkExponent (n : Nat) : CoreM Bool := do +def checkExponent (n : Nat) (warning := true) : CoreM Bool := do let threshold := exponentiation.threshold.get (← getOptions) if n > threshold then - if (← logMessageKind `unsafe.exponentiation) then + if (← pure warning <&&> logMessageKind `unsafe.exponentiation) then logWarning s!"exponent {n} exceeds the threshold {threshold}, exponentiation operation was not evaluated, use `set_option {exponentiation.threshold.name} ` to set a new threshold" return false else diff --git a/tests/lean/run/grind_cutsat_toint_1.lean b/tests/lean/run/grind_cutsat_toint_1.lean index 3432ff08ee..6d982d83ea 100644 --- a/tests/lean/run/grind_cutsat_toint_1.lean +++ b/tests/lean/run/grind_cutsat_toint_1.lean @@ -94,3 +94,22 @@ example (a b : Fin 3) : a > 0 → a ≠ b → a + b ≠ 0 → a + b ≠ 1 → Fa set_option trace.grind.debug.ring.basis true in example (a b : Fin 3) : a > 0 → a ≠ b → a + b ≠ 0 → a + b ≠ 1 → False := by grind + +example (p : Nat) (heq : p = 0) (n : Fin (p + 1)) : n = 0 := by + grind + +example (p : Nat) (heq : p = 1) (n : Fin (p + 1)) : n = 0 ∨ n = 1 := by + grind + +example (p d : Nat) (n : Fin (p + 1)) : 2 ≤ p → p ≤ d + 1 → d = 1 → n = 0 ∨ n = 1 ∨ n = 2 := by + grind + +example (s : Nat) + (i j : Fin (s + 1)) (hn : i ≠ j) (hl : ¬i < j) : j < i := by + grind + +example {n : Nat} (j : Fin (n + 1)) : j ≤ j := by + grind + +example {n : Nat} (x y : Fin ((n + 1) + 1)) (h₂ : ¬x = y) (h : ¬x < y) : y < x := by + grind