From f5fd962a25114d0be4c45fce82def56f84a46a50 Mon Sep 17 00:00:00 2001
From: Leonardo de Moura
Date: Wed, 3 Jul 2024 07:12:53 +0200
Subject: [PATCH] feat: safe exponentiation (#4637)

Summary:
- Adds configuration option `exponentiation.threshold`
- An expression `b^n` where `b` and `n` are literals is not reduced by
`whnf`, `simp`, and `isDefEq` if `n > exponentiation.threshold`.

Motivation: prevents the system from becoming unresponsive and/or crashing
by running out of memory.

TODO: improve support in the kernel. It is using a hard-coded limit for
now.
---
 src/Lean/CoreM.lean                           | 12 +++++
 src/Lean/Message.lean                         |  7 +++
 src/Lean/Meta/Offset.lean                     | 14 ++++--
 .../Meta/Tactic/Simp/BuiltinSimprocs/Int.lean |  1 +
 .../Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean |  9 +++-
 src/Lean/Meta/WHNF.lean                       | 10 +++-
 src/Lean/Util.lean                            |  1 +
 src/Lean/Util/SafeExponentiation.lean         | 34 ++++++++++++++
 src/kernel/type_checker.cpp                   | 14 +++++-
 src/kernel/type_checker.h                     |  1 +
 tests/lean/run/lean_nat_gcd.lean              |  1 +
 tests/lean/run/safeExp.lean                   | 47 +++++++++++++++++++
 12 files changed, 144 insertions(+), 7 deletions(-)
 create mode 100644 src/Lean/Util/SafeExponentiation.lean
 create mode 100644 tests/lean/run/safeExp.lean

diff --git a/src/Lean/CoreM.lean b/src/Lean/CoreM.lean
index 7d23cd9d5c..75eae9b8ad 100644
--- a/src/Lean/CoreM.lean
+++ b/src/Lean/CoreM.lean
@@ -519,4 +519,16 @@ instance : MonadRuntimeException CoreM where
 @[inline] def mapCoreM [MonadControlT CoreM m] [Monad m] (f : forall {α}, CoreM α → CoreM α) {α} (x : m α) : m α :=
   controlAt CoreM fun runInBase => f <| runInBase x
 
+/--
+Returns `true` if the given message kind has not been reported in the message log,
+and then marks it as reported. Otherwise, returns `false`.
+We use this API to ensure we don't report the same kind of warning multiple times.
+-/
+def reportMessageKind (kind : Name) : CoreM Bool := do
+  if (← get).messages.reportedKinds.contains kind then
+    return false
+  else
+    modify fun s => { s with messages.reportedKinds := s.messages.reportedKinds.insert kind }
+    return true
+
 end Lean
diff --git a/src/Lean/Message.lean b/src/Lean/Message.lean
index 9cb2aecc34..edc0c6aa23 100644
--- a/src/Lean/Message.lean
+++ b/src/Lean/Message.lean
@@ -359,6 +359,13 @@ structure MessageLog where
   hadErrors : Bool := false
   /-- The list of messages not already reported, in insertion order. -/
   unreported : PersistentArray Message := {}
+  /--
+  Set of message kinds that have been added to the log.
+  For example, we have the kind `unsafe.exponentiation` for warning messages associated with
+  the configuration option `exponentiation.threshold`.
+  We don't produce a warning if the kind is already in the following set.
+  -/
+  reportedKinds : NameSet := {}
   deriving Inhabited
 
 namespace MessageLog
diff --git a/src/Lean/Meta/Offset.lean b/src/Lean/Meta/Offset.lean
index a494098222..b3afaf728a 100644
--- a/src/Lean/Meta/Offset.lean
+++ b/src/Lean/Meta/Offset.lean
@@ -7,6 +7,7 @@ prelude
 import Lean.Data.LBool
 import Lean.Meta.InferType
 import Lean.Meta.NatInstTesters
+import Lean.Util.SafeExponentiation
 
 namespace Lean.Meta
 
@@ -29,6 +30,10 @@ partial def evalNat (e : Expr) : OptionT MetaM Nat := do
   | .mvar ..
=> visit e | _ => failure where + evalPow (b n : Expr) : OptionT MetaM Nat := do + let n ← evalNat n + guard (← checkExponent n) + return (← evalNat b) ^ n visit e := do match_expr e with | OfNat.ofNat _ n i => guard (← isInstOfNatNat i); evalNat n @@ -48,10 +54,10 @@ where | Nat.mod a b => return (← evalNat a) % (← evalNat b) | Mod.mod _ i a b => guard (← isInstModNat i); return (← evalNat a) % (← evalNat b) | HMod.hMod _ _ _ i a b => guard (← isInstHModNat i); return (← evalNat a) % (← evalNat b) - | Nat.pow a b => return (← evalNat a) ^ (← evalNat b) - | NatPow.pow _ i a b => guard (← isInstNatPowNat i); return (← evalNat a) ^ (← evalNat b) - | Pow.pow _ _ i a b => guard (← isInstPowNat i); return (← evalNat a) ^ (← evalNat b) - | HPow.hPow _ _ _ i a b => guard (← isInstHPowNat i); return (← evalNat a) ^ (← evalNat b) + | Nat.pow a b => evalPow a b + | NatPow.pow _ i a b => guard (← isInstNatPowNat i); evalPow a b + | Pow.pow _ _ i a b => guard (← isInstPowNat i); evalPow a b + | HPow.hPow _ _ _ i a b => guard (← isInstHPowNat i); evalPow a b | _ => failure /-- diff --git a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean index 3939a4b49e..930ca1e4a6 100644 --- a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean +++ b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean @@ -82,6 +82,7 @@ builtin_dsimproc [simp, seval] reducePow ((_ : Int) ^ (_ : Nat)) := fun e => do let_expr HPow.hPow _ _ _ _ a b ← e | return .continue let some v₁ ← fromExpr? a | return .continue let some v₂ ← Nat.fromExpr? b | return .continue + unless (← checkExponent v₂) do return .continue return .done <| toExpr (v₁ ^ v₂) builtin_simproc [simp, seval] reduceLT (( _ : Int) < _) := reduceBinPred ``LT.lt 4 (. < .) 
diff --git a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean
index 23c0d1b90b..2a1a4163fb 100644
--- a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean
+++ b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean
@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
 prelude
 import Init.Simproc
 import Init.Data.Nat.Simproc
+import Lean.Util.SafeExponentiation
 import Lean.Meta.LitValues
 import Lean.Meta.Offset
 import Lean.Meta.Tactic.Simp.Simproc
@@ -52,7 +53,16 @@ builtin_dsimproc [simp, seval] reduceMul ((_ * _ : Nat)) := reduceBin ``HMul.hMu
 builtin_dsimproc [simp, seval] reduceSub ((_ - _ : Nat)) := reduceBin ``HSub.hSub 6 (· - ·)
 builtin_dsimproc [simp, seval] reduceDiv ((_ / _ : Nat)) := reduceBin ``HDiv.hDiv 6 (· / ·)
 builtin_dsimproc [simp, seval] reduceMod ((_ % _ : Nat)) := reduceBin ``HMod.hMod 6 (· % ·)
-builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := reduceBin ``HPow.hPow 6 (· ^ ·)
+
+-- Evaluates `a ^ b` when both operands are `Nat` literals and `checkExponent` accepts `b`.
+-- The `HPow.hPow` application is matched explicitly so an unexpected shape yields `.continue`.
+builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := fun e => do
+  let_expr HPow.hPow _ _ _ _ a b ← e | return .continue
+  let some n ← fromExpr? a | return .continue
+  let some m ← fromExpr? b | return .continue
+  unless (← checkExponent m) do return .continue
+  return .done <| toExpr (n ^ m)
+
 builtin_dsimproc [simp, seval] reduceGcd (gcd _ _) := reduceBin ``gcd 2 gcd
 
 builtin_simproc [simp, seval] reduceLT (( _ : Nat) < _) := reduceBinPred ``LT.lt 4 (. < .)
diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index ebd379279b..f1c7915934 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Lean.Structure import Lean.Util.Recognizers +import Lean.Util.SafeExponentiation import Lean.Meta.GetUnfoldableConst import Lean.Meta.FunInfo import Lean.Meta.Offset @@ -885,6 +886,13 @@ def reduceBinNatOp (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Option Expr) trace[Meta.isDefEq.whnf.reduceBinOp] "{a} op {b}" return mkRawNatLit <| f a b +def reducePow (a b : Expr) : MetaM (Option Expr) := + withNatValue a fun a => + withNatValue b fun b => OptionT.run do + guard (← checkExponent b) + trace[Meta.isDefEq.whnf.reduceBinOp] "{a} ^ {b}" + return mkRawNatLit <| a ^ b + def reduceBinNatPred (f : Nat → Nat → Bool) (a b : Expr) : MetaM (Option Expr) := do withNatValue a fun a => withNatValue b fun b => @@ -904,7 +912,7 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) := | ``Nat.mul => reduceBinNatOp Nat.mul a1 a2 | ``Nat.div => reduceBinNatOp Nat.div a1 a2 | ``Nat.mod => reduceBinNatOp Nat.mod a1 a2 - | ``Nat.pow => reduceBinNatOp Nat.pow a1 a2 + | ``Nat.pow => reducePow a1 a2 | ``Nat.gcd => reduceBinNatOp Nat.gcd a1 a2 | ``Nat.beq => reduceBinNatPred Nat.beq a1 a2 | ``Nat.ble => reduceBinNatPred Nat.ble a1 a2 diff --git a/src/Lean/Util.lean b/src/Lean/Util.lean index 9df7ef2f8a..dbdfa95534 100644 --- a/src/Lean/Util.lean +++ b/src/Lean/Util.lean @@ -29,3 +29,4 @@ import Lean.Util.OccursCheck import Lean.Util.HasConstCache import Lean.Util.FileSetupInfo import Lean.Util.Heartbeats +import Lean.Util.SafeExponentiation diff --git a/src/Lean/Util/SafeExponentiation.lean b/src/Lean/Util/SafeExponentiation.lean new file mode 100644 index 0000000000..4a402e6e4f --- /dev/null +++ b/src/Lean/Util/SafeExponentiation.lean @@ -0,0 +1,34 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Leonardo de Moura
+-/
+prelude
+import Lean.CoreM
+
+namespace Lean
+
+register_builtin_option exponentiation.threshold : Nat := {
+  defValue := 256
+  descr := "maximum value for \
+    which exponentiation operations are safe to evaluate. When an exponent \
+    is a value greater than this threshold, the exponentiation will not be evaluated, \
+    and a warning will be logged. This helps to prevent the system from becoming \
+    unresponsive due to excessively large computations."
+}
+
+/--
+Returns `true` if `n` is `≤ exponentiation.threshold`. Otherwise,
+reports a warning and returns `false`.
+This method ensures there is at most one warning message of this kind in the message log.
+-/
+def checkExponent (n : Nat) : CoreM Bool := do
+  let threshold := exponentiation.threshold.get (← getOptions)
+  if n > threshold then
+    if (← reportMessageKind `unsafe.exponentiation) then
+      logWarning s!"exponent {n} exceeds the threshold {threshold}, exponentiation operation was not evaluated, use `set_option {exponentiation.threshold.name} ` to set a new threshold"
+    return false
+  else
+    return true
+
+end Lean
diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp
index 1dbb0819a9..013c960f64 100644
--- a/src/kernel/type_checker.cpp
+++ b/src/kernel/type_checker.cpp
@@ -595,6 +595,18 @@ template optional type_checker::reduce_bin_nat_op(F const & f,
     return some_expr(mk_lit(literal(nat(f(v1.raw(), v2.raw())))));
 }
 
+#define ReducePowMaxExp (1<<24) // TODO: make it configurable
+
+optional<expr> type_checker::reduce_pow(expr const & e) {
+    expr arg1 = whnf(app_arg(app_fn(e))); // base
+    expr arg2 = whnf(app_arg(e));         // exponent
+    if (!is_nat_lit_ext(arg1) || !is_nat_lit_ext(arg2)) return none_expr(); // guard both operands before get_nat_val
+    nat v1 = get_nat_val(arg1);
+    nat v2 = get_nat_val(arg2);
+    if (v2 > nat(ReducePowMaxExp)) return none_expr(); // exponent above the hard-coded limit: leave unreduced
+    return some_expr(mk_lit(literal(nat(nat_pow(v1.raw(), v2.raw())))));
+}
+
 template optional type_checker::reduce_bin_nat_pred(F const
& f, expr const & e) { expr arg1 = whnf(app_arg(app_fn(e))); if (!is_nat_lit_ext(arg1)) return none_expr(); @@ -622,7 +634,7 @@ optional type_checker::reduce_nat(expr const & e) { if (f == *g_nat_add) return reduce_bin_nat_op(nat_add, e); if (f == *g_nat_sub) return reduce_bin_nat_op(nat_sub, e); if (f == *g_nat_mul) return reduce_bin_nat_op(nat_mul, e); - if (f == *g_nat_pow) return reduce_bin_nat_op(nat_pow, e); + if (f == *g_nat_pow) return reduce_pow(e); if (f == *g_nat_gcd) return reduce_bin_nat_op(nat_gcd, e); if (f == *g_nat_mod) return reduce_bin_nat_op(nat_mod, e); if (f == *g_nat_div) return reduce_bin_nat_op(nat_div, e); diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index e281c96474..3ab2ce643b 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -101,6 +101,7 @@ private: template optional reduce_bin_nat_op(F const & f, expr const & e); template optional reduce_bin_nat_pred(F const & f, expr const & e); + optional reduce_pow(expr const & e); optional reduce_nat(expr const & e); public: type_checker(state & st, local_ctx const & lctx, definition_safety ds = definition_safety::safe); diff --git a/tests/lean/run/lean_nat_gcd.lean b/tests/lean/run/lean_nat_gcd.lean index cb91e6e53c..4cb2b47278 100644 --- a/tests/lean/run/lean_nat_gcd.lean +++ b/tests/lean/run/lean_nat_gcd.lean @@ -49,6 +49,7 @@ def p_31 := 216091 def p_32 := 756839 def p_33 := 859433 +set_option exponentiation.threshold 10000000 /- GCD with large prime factors on one side, and small primes on the other. -/ example : Nat.gcd (p_29 * p_30 * p_31 * p_32 * p_33) 2^(2^20) = 1 := rfl /- GCD with two prime factors on both sides, including one in common. 
-/ diff --git a/tests/lean/run/safeExp.lean b/tests/lean/run/safeExp.lean new file mode 100644 index 0000000000..ef27868f32 --- /dev/null +++ b/tests/lean/run/safeExp.lean @@ -0,0 +1,47 @@ +/-- +warning: exponent 10000000 exceeds the threshold 256, exponentiation operation was not evaluated, use `set_option exponentiation.threshold ` to set a new threshold +--- +error: maximum recursion depth has been reached +use `set_option maxRecDepth ` to increase limit +use `set_option diagnostics true` to get diagnostic information +-/ +#guard_msgs in +example : 2^2^8000000 = 3^3^10000000 := + rfl + +/-- +-/ +#guard_msgs in +set_option exponentiation.threshold 258 in +example : 2^257 = 2*2^256 := + rfl + +/-- +warning: exponent 2008 exceeds the threshold 256, exponentiation operation was not evaluated, use `set_option exponentiation.threshold ` to set a new threshold +--- +warning: declaration uses 'sorry' +--- +error: (kernel) deep recursion detected +--- +info: k : Nat +h : k = 2008 ^ 2 + 2 ^ 2008 +⊢ ((4032064 + 2 ^ 2008) ^ 2 + 2 ^ (4032064 + 2 ^ 2008)) % 10 = 6 +-/ +#guard_msgs in +example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by + simp [h] + trace_state + sorry + +/-- +warning: declaration uses 'sorry' +--- +info: k : Nat +h : k = 2008 ^ 2 + 2 ^ 2008 +⊢ ((2008 ^ 2 + 2 ^ 2008) ^ 2 + 2 ^ (2008 ^ 2 + 2 ^ 2008)) % 10 = 6 +-/ +#guard_msgs in +example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by + rw [h] + trace_state + sorry