From e2b3b34d14879ce02a6acf1c8599d06f64258177 Mon Sep 17 00:00:00 2001
From: Joe Hendrix
Date: Mon, 19 Feb 2024 07:04:51 -0800
Subject: [PATCH] feat: introduce native functions for Int.ediv / Int.emod (#3376)

These still need tests, but I thought I'd upstream so I can use benchmarking and check for build errors.
---
 src/Init/Data/Int/DivMod.lean            |  6 +-
 src/include/lean/lean.h                  | 77 +++++++++++++++++++++
 src/runtime/mpz.cpp                      | 85 ++++++++++++++++++++++++
 src/runtime/mpz.h                        |  8 +++
 src/runtime/object.cpp                   | 30 +++++++++
 tests/lean/int_div_mod.lean              | 62 +++++++++++++++++
 tests/lean/int_div_mod.lean.expected.out |  0
 7 files changed, 266 insertions(+), 2 deletions(-)
 create mode 100644 tests/lean/int_div_mod.lean
 create mode 100644 tests/lean/int_div_mod.lean.expected.out

diff --git a/src/Init/Data/Int/DivMod.lean b/src/Init/Data/Int/DivMod.lean
index 8a7bf71aeb..16f0bddc5d 100644
--- a/src/Init/Data/Int/DivMod.lean
+++ b/src/Init/Data/Int/DivMod.lean
@@ -131,7 +131,8 @@ Integer division. This version of `Int.div` uses the E-rounding convention
 (euclidean division), in which `Int.emod x y` satisfies `0 ≤ mod x y < natAbs y` for `y ≠ 0`
 and `Int.ediv` is the unique function satisfying `emod x y + (ediv x y) * y = x`.
 -/
-def ediv : Int → Int → Int
+@[extern "lean_int_ediv"]
+def ediv : (@& Int) → (@& Int) → Int
   | ofNat m, ofNat n => ofNat (m / n)
   | ofNat m, -[n+1] => -ofNat (m / succ n)
   | -[_+1], 0 => 0
@@ -143,7 +144,8 @@ Integer modulus. This version of `Int.mod` uses the E-rounding convention
 (euclidean division), in which `Int.emod x y` satisfies `0 ≤ emod x y < natAbs y` for `y ≠ 0`
 and `Int.ediv` is the unique function satisfying `emod x y + (ediv x y) * y = x`.
 -/
-def emod : Int → Int → Int
+@[extern "lean_int_emod"]
+def emod : (@& Int) → (@& Int) → Int
   | ofNat m, n => ofNat (m % natAbs n)
   | -[m+1], n => subNatNat (natAbs n) (succ (m % natAbs n))
 
diff --git a/src/include/lean/lean.h b/src/include/lean/lean.h
index 7b1f320b42..23d7abc83d 100644
--- a/src/include/lean/lean.h
+++ b/src/include/lean/lean.h
@@ -1320,6 +1320,8 @@ LEAN_SHARED lean_object * lean_int_big_sub(lean_object * a1, lean_object * a2);
 LEAN_SHARED lean_object * lean_int_big_mul(lean_object * a1, lean_object * a2);
 LEAN_SHARED lean_object * lean_int_big_div(lean_object * a1, lean_object * a2);
 LEAN_SHARED lean_object * lean_int_big_mod(lean_object * a1, lean_object * a2);
+LEAN_SHARED lean_object * lean_int_big_ediv(lean_object * a1, lean_object * a2);
+LEAN_SHARED lean_object * lean_int_big_emod(lean_object * a1, lean_object * a2);
 LEAN_SHARED bool lean_int_big_eq(lean_object * a1, lean_object * a2);
 LEAN_SHARED bool lean_int_big_le(lean_object * a1, lean_object * a2);
 LEAN_SHARED bool lean_int_big_lt(lean_object * a1, lean_object * a2);
@@ -1461,6 +1463,81 @@ static inline lean_obj_res lean_int_mod(b_lean_obj_arg a1, b_lean_obj_arg a2) {
     }
 }
 
+/*
+lean_int_ediv and lean_int_emod implement "Euclidean" division and modulus using the
+algorithm in:
+  Division and Modulus for Computer Scientists
+  Daan Leijen
+  https://www.microsoft.com/en-us/research/publication/division-and-modulus-for-computer-scientists/
+Division by zero does not trap: lean_int_ediv returns 0 and lean_int_emod returns its first argument.
+*/
+
+static inline lean_obj_res lean_int_ediv(b_lean_obj_arg a1, b_lean_obj_arg a2) {
+    if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
+        if (sizeof(void*) == 8) {
+            /* 64-bit version, we use 64-bit numbers to avoid overflow when n == LEAN_MIN_SMALL_INT. */
+            int64_t n = lean_scalar_to_int64(a1);
+            int64_t d = lean_scalar_to_int64(a2);
+            if (d == 0)
+                return lean_box(0);
+            else {
+                int64_t q = n / d;
+                int64_t r = n % d;
+                if (r < 0) /* C division truncates; move q one step so that the remainder becomes non-negative */
+                    q = (d > 0) ? q - 1 : q + 1;
+                return lean_int64_to_int(q);
+            }
+        } else {
+            /* 32-bit version */
+            int n = lean_scalar_to_int(a1);
+            int d = lean_scalar_to_int(a2);
+            if (d == 0) {
+                return lean_box(0);
+            } else {
+                int q = n / d;
+                int r = n % d;
+                if (r < 0) /* same adjustment as in the 64-bit branch */
+                    q = (d > 0) ? q - 1 : q + 1;
+                return lean_int_to_int(q);
+            }
+        }
+    } else {
+        return lean_int_big_ediv(a1, a2);
+    }
+}
+
+static inline lean_obj_res lean_int_emod(b_lean_obj_arg a1, b_lean_obj_arg a2) {
+    if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
+        if (sizeof(void*) == 8) {
+            /* 64-bit version, we use 64-bit numbers to avoid overflow when n == LEAN_MIN_SMALL_INT. */
+            int64_t n = lean_scalar_to_int64(a1);
+            int64_t d = lean_scalar_to_int64(a2);
+            if (d == 0) {
+                return a1;
+            } else {
+                int64_t r = n % d;
+                if (r < 0) /* add |d| so that the result is non-negative */
+                    r = (d > 0) ? r + d : r - d;
+                return lean_int64_to_int(r);
+            }
+        } else {
+            /* 32-bit version */
+            int n = lean_scalar_to_int(a1);
+            int d = lean_scalar_to_int(a2);
+            if (d == 0)
+                return a1;
+            else {
+                int r = n % d;
+                if (r < 0) /* add |d| so that the result is non-negative */
+                    r = (d > 0) ? r + d : r - d;
+                return lean_int_to_int(r);
+            }
+        }
+    } else {
+        return lean_int_big_emod(a1, a2);
+    }
+}
+
 static inline bool lean_int_eq(b_lean_obj_arg a1, b_lean_obj_arg a2) {
     if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
         return a1 == a2;
diff --git a/src/runtime/mpz.cpp b/src/runtime/mpz.cpp
index c9183f7846..8dc4a71c7d 100644
--- a/src/runtime/mpz.cpp
+++ b/src/runtime/mpz.cpp
@@ -160,6 +160,43 @@ mpz & mpz::operator*=(unsigned u) { mpz_mul_ui(m_val, m_val, u); return *this; }
 
 mpz & mpz::operator*=(int u) { mpz_mul_si(m_val, m_val, u); return *this; }
 
+mpz mpz::ediv(mpz const & n, mpz const & d) {
+    mpz q;
+    mpz_t r;
+    mpz_init(r);
+    /* (q,r) = (n/d, n%d) */
+    mpz_tdiv_qr(q.m_val, r, n.m_val, d.m_val);
+    /* if (r < 0) */
+    if (mpz_sgn(r) < 0) {
+        if (mpz_sgn(d.m_val) > 0) {
+            /* q = q - 1. */
+            mpz_sub_ui(q.m_val, q.m_val, 1);
+        } else {
+            /* q = q + 1. */
+            mpz_add_ui(q.m_val, q.m_val, 1);
+        }
+    }
+    mpz_clear(r);
+    return q;
+}
+
+mpz mpz::emod(mpz const & n, mpz const & d) {
+    mpz r;
+    /* (q,r) = (n/d, n%d) */
+    mpz_tdiv_r(r.m_val, n.m_val, d.m_val);
+    /* if (r < 0) */
+    if (mpz_sgn(r.m_val) < 0) {
+        if (mpz_sgn(d.m_val) > 0) {
+            /* r = r + d. */
+            mpz_add(r.m_val, r.m_val, d.m_val);
+        } else {
+            /* r = r - d. */
+            mpz_sub(r.m_val, r.m_val, d.m_val);
+        }
+    }
+    return r;
+}
+
 mpz & mpz::operator/=(mpz const & o) { mpz_tdiv_q(m_val, m_val, o.m_val); return *this; }
 
 mpz & mpz::operator/=(unsigned u) { mpz_tdiv_q_ui(m_val, m_val, u); return *this; }
@@ -630,6 +667,7 @@ mpz & mpz::rem(size_t sz, mpn_digit const * digits) {
             digits, sz,
             q1.begin(), r1.begin());
     set(r_sz, r1.begin());
+    m_sign = m_sign && !is_zero();
     return *this;
 }
 
@@ -699,6 +737,53 @@ mpz & mpz::operator%=(mpz const & o) {
     return rem(o.m_size, o.m_digits);
 }
 
+mpz mpz::ediv(mpz const & n, mpz const & d) {
+    if (d.m_size > n.m_size) { /* d has more digits than n; presumably |n| < |d| for normalized values */
+        if (n.is_neg()) {
+            int64_t r = d.is_pos() ? -1 : 1;
+            return mpz(r);
+        } else {
+            return mpz(0);
+        }
+    } else {
+        digit_buffer q1, r1;
+        size_t q_sz = n.m_size - d.m_size + 1;
+        size_t r_sz = d.m_size;
+        q1.ensure_capacity(q_sz);
+        r1.ensure_capacity(r_sz);
+        mpn_div(n.m_digits, n.m_size,
+                d.m_digits, d.m_size,
+                q1.begin(), r1.begin());
+        mpz q;
+        q.set(q_sz, q1.begin());
+        q.m_sign = !q.is_zero() && n.m_sign != d.m_sign;
+        mpz r;
+        r.set(r_sz, r1.begin());
+        r.m_sign = n.m_sign && !r.is_zero();
+        if (r.is_neg()) { /* same quotient adjustment as the scalar lean_int_ediv */
+            if (d.is_pos()) {
+                q -= 1;
+            } else {
+                q += 1;
+            }
+        }
+        return q;
+    }
+}
+
+mpz mpz::emod(mpz const & n, mpz const & d) {
+    mpz r(n);
+    r.rem(d.m_size, d.m_digits);
+    if (r.is_neg()) { /* add |d| so that the result is non-negative */
+        if (d.is_pos()) {
+            r += d;
+        } else {
+            r -= d;
+        }
+    }
+    return r;
+}
+
 mpz mpz::pow(unsigned int p) const {
     unsigned mask = 1;
     mpz power(*this);
diff --git a/src/runtime/mpz.h b/src/runtime/mpz.h
index 2b30d05ba4..e0680916f4 100644
--- a/src/runtime/mpz.h
+++ b/src/runtime/mpz.h
@@ -245,6 +245,14 @@ public:
 
     friend mpz operator%(mpz a, mpz const & b) { return a %= b; }
 
+    static mpz ediv(mpz const & n, mpz const & d);
+    static mpz ediv(int n, mpz const & d) { return ediv(mpz(n), d); }
+    static mpz ediv(mpz const & n, int d) { return ediv(n, mpz(d)); }
+
+    static mpz emod(mpz const & n, mpz const & d);
+    static mpz emod(int n, mpz const & d) { return emod(mpz(n), d); }
+    static mpz emod(mpz const & n, int d) { return emod(n, mpz(d)); }
+
     mpz & operator&=(mpz const & o);
     mpz & operator|=(mpz const & o);
     mpz & operator^=(mpz const & o);
diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp
index 3077a8a4fd..2106325d0e 100644
--- a/src/runtime/object.cpp
+++ b/src/runtime/object.cpp
@@ -1432,6 +1432,36 @@ extern "C" LEAN_EXPORT object * lean_int_big_mod(object * a1, object * a2) {
     }
 }
 
+extern "C" LEAN_EXPORT object * lean_int_big_ediv(object * a1, object * a2) {
+    if (lean_is_scalar(a1)) {
+        return mpz_to_int(mpz::ediv(lean_scalar_to_int(a1), mpz_value(a2)));
+    } else if (lean_is_scalar(a2)) {
+        int d = lean_scalar_to_int(a2);
+        if (d == 0)
+            return a2; /* a2 is the scalar 0 here, which is also the result of dividing by zero */
+        else
+            return mpz_to_int(mpz::ediv(mpz_value(a1), d));
+    } else {
+        return mpz_to_int(mpz::ediv(mpz_value(a1), mpz_value(a2)));
+    }
+}
+
+extern "C" LEAN_EXPORT object * lean_int_big_emod(object * a1, object * a2) {
+    if (lean_is_scalar(a1)) {
+        return mpz_to_int(mpz::emod(lean_scalar_to_int(a1), mpz_value(a2)));
+    } else if (lean_is_scalar(a2)) {
+        int i2 = lean_scalar_to_int(a2);
+        if (i2 == 0) {
+            lean_inc(a1); /* modulus by zero returns a1 itself, so take an extra reference first */
+            return a1;
+        } else {
+            return mpz_to_int(mpz::emod(mpz_value(a1), i2));
+        }
+    } else {
+        return mpz_to_int(mpz::emod(mpz_value(a1), mpz_value(a2)));
+    }
+}
+
 extern "C" LEAN_EXPORT bool lean_int_big_eq(object * a1, object * a2) {
     if (lean_is_scalar(a1)) {
         lean_assert(lean_scalar_to_int(a1) != mpz_value(a2))
diff --git a/tests/lean/int_div_mod.lean b/tests/lean/int_div_mod.lean
new file mode 100644
index 0000000000..6bc9b1a323
--- /dev/null
+++ b/tests/lean/int_div_mod.lean
@@ -0,0 +1,62 @@
+-- Divide by zero tests
+#guard ( 0 : Int) / 0 = 0
+#guard ( 0 : Int) % 0 == 0
+#guard ( 4 : Int) / 0 == 0
+#guard ( 4 : Int) % 0 == 4
+#guard (-4 : Int) / 0 == 0
+#guard (-4 : Int) % 0 == -4
+#guard ( 0 : Int) / 4 == 0
+#guard ( 0 : Int) % 4 == 0
+#guard ( 0 : Int) / -4 == 0
+#guard ( 0 : Int) % -4 == 0
+
+-- Euclidean division tests
+#guard ( 4 : Int) / 3 == 1
+#guard ( 4 : Int) % 3 == 1
+#guard ( 5 : Int) / 3 == 1
+#guard ( 5 : Int) % 3 == 2
+#guard ( 6 : Int) / 3 == 2
+#guard ( 6 : Int) % 3 == 0
+#guard ( 7 : Int) / 4 == 1
+#guard ( 7 : Int) % 4 == 3
+
+#guard ( 4 : Int) / -3 == -1
+#guard ( 4 : Int) % -3 == 1
+#guard ( 5 : Int) / -3 == -1
+#guard ( 5 : Int) % -3 == 2
+#guard ( 6 : Int) / -3 == -2
+#guard ( 6 : Int) % -3 == 0
+#guard ( 7 : Int) / -4 == -1
+#guard ( 7 : Int) % -4 == 3
+
+#guard (-4 : Int) / 3 == -2
+#guard (-4 : Int) % 3 == 2
+#guard (-5 : Int) / 3 == -2
+#guard (-5 : Int) % 3 == 1
+#guard (-6 : Int) / 3 == -2
+#guard (-6 : Int) % 3 == 0
+#guard (-7 : Int) / 4 == -2
+#guard (-7 : Int) % 4 == 1
+
+#guard (-4 : Int) / -3 == 2
+#guard (-4 : Int) % -3 == 2
+#guard (-5 : Int) / -3 == 2
+#guard (-5 : Int) % -3 == 1
+#guard (-6 : Int) / -3 == 2
+#guard (-6 : Int) % -3 == 0
+#guard (-7 : Int) / -4 == 2
+#guard (-7 : Int) % -4 == 1
+
+-- Basic big integer tests
+#guard let n : Int := 0; let d : Int := 2^64; n / d = 0 ∧ n % d = n
+#guard let n : Int := 1; let d : Int := 2^64; n / d = 0 ∧ n % d = n
+#guard let n : Int := -1; let d : Int := 2^64; n / d = -1 ∧ n % d = (d + n)
+#guard let n : Int := 2^128; let d : Int := 3; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
+#guard let n : Int := 2^128; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
+#guard let n : Int := -2^128; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
+#guard let n : Int := 2^128; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
+#guard let n : Int := -2^128; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
+#guard let n : Int := 2^128+7; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
+#guard let n : Int := -2^128+3; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
+#guard let n : Int := 2^128+2; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
+#guard let n : Int := -2^128+2; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
diff --git a/tests/lean/int_div_mod.lean.expected.out b/tests/lean/int_div_mod.lean.expected.out
new file mode 100644
index 0000000000..e69de29bb2