feat: introduce native functions for Int.ediv / Int.emod (#3376)
These still need tests, but I thought I'd upstream so I can use benchmarking and check for build errors.
This commit is contained in:
parent
204b408df7
commit
e2b3b34d14
7 changed files with 266 additions and 2 deletions
|
|
@ -131,7 +131,8 @@ Integer division. This version of `Int.div` uses the E-rounding convention
|
|||
(euclidean division), in which `Int.emod x y` satisfies `0 ≤ mod x y < natAbs y` for `y ≠ 0`
|
||||
and `Int.ediv` is the unique function satisfying `emod x y + (ediv x y) * y = x`.
|
||||
-/
|
||||
def ediv : Int → Int → Int
|
||||
@[extern "lean_int_ediv"]
|
||||
def ediv : (@& Int) → (@& Int) → Int
|
||||
| ofNat m, ofNat n => ofNat (m / n)
|
||||
| ofNat m, -[n+1] => -ofNat (m / succ n)
|
||||
| -[_+1], 0 => 0
|
||||
|
|
@ -143,7 +144,8 @@ Integer modulus. This version of `Int.mod` uses the E-rounding convention
|
|||
(euclidean division), in which `Int.emod x y` satisfies `0 ≤ emod x y < natAbs y` for `y ≠ 0`
|
||||
and `Int.ediv` is the unique function satisfying `emod x y + (ediv x y) * y = x`.
|
||||
-/
|
||||
def emod : Int → Int → Int
|
||||
@[extern "lean_int_emod"]
|
||||
def emod : (@& Int) → (@& Int) → Int
|
||||
| ofNat m, n => ofNat (m % natAbs n)
|
||||
| -[m+1], n => subNatNat (natAbs n) (succ (m % natAbs n))
|
||||
|
||||
|
|
|
|||
|
|
@ -1320,6 +1320,8 @@ LEAN_SHARED lean_object * lean_int_big_sub(lean_object * a1, lean_object * a2);
|
|||
LEAN_SHARED lean_object * lean_int_big_mul(lean_object * a1, lean_object * a2);
|
||||
LEAN_SHARED lean_object * lean_int_big_div(lean_object * a1, lean_object * a2);
|
||||
LEAN_SHARED lean_object * lean_int_big_mod(lean_object * a1, lean_object * a2);
|
||||
LEAN_SHARED lean_object * lean_int_big_ediv(lean_object * a1, lean_object * a2);
|
||||
LEAN_SHARED lean_object * lean_int_big_emod(lean_object * a1, lean_object * a2);
|
||||
LEAN_SHARED bool lean_int_big_eq(lean_object * a1, lean_object * a2);
|
||||
LEAN_SHARED bool lean_int_big_le(lean_object * a1, lean_object * a2);
|
||||
LEAN_SHARED bool lean_int_big_lt(lean_object * a1, lean_object * a2);
|
||||
|
|
@ -1461,6 +1463,81 @@ static inline lean_obj_res lean_int_mod(b_lean_obj_arg a1, b_lean_obj_arg a2) {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
lean_int_ediv and lean_int_emod implement "Euclidean" division and modulus using the
|
||||
algorithm in:
|
||||
Division and Modulus for Computer Scientists
|
||||
Daan Leijen
|
||||
https://www.microsoft.com/en-us/research/publication/division-and-modulus-for-computer-scientists/
|
||||
|
||||
*/
|
||||
|
||||
static inline lean_obj_res lean_int_ediv(b_lean_obj_arg a1, b_lean_obj_arg a2) {
|
||||
if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
|
||||
if (sizeof(void*) == 8) {
|
||||
/* 64-bit version, we use 64-bit numbers to avoid overflow when v1 == LEAN_MIN_SMALL_INT. */
|
||||
int64_t n = lean_scalar_to_int(a1);
|
||||
int64_t d = lean_scalar_to_int(a2);
|
||||
if (d == 0)
|
||||
return lean_box(0);
|
||||
else {
|
||||
int64_t q = n / d;
|
||||
int64_t r = n % d;
|
||||
if (r < 0)
|
||||
q = (d > 0) ? q - 1 : q + 1;
|
||||
return lean_int64_to_int(q);
|
||||
}
|
||||
} else {
|
||||
/* 32-bit version */
|
||||
int n = lean_scalar_to_int(a1);
|
||||
int d = lean_scalar_to_int(a2);
|
||||
if (d == 0) {
|
||||
return lean_box(0);
|
||||
} else {
|
||||
int q = n / d;
|
||||
int r = n % d;
|
||||
if (r < 0)
|
||||
q = (d > 0) ? q - 1 : q + 1;
|
||||
return lean_int_to_int(q);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return lean_int_big_ediv(a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
static inline lean_obj_res lean_int_emod(b_lean_obj_arg a1, b_lean_obj_arg a2) {
|
||||
if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
|
||||
if (sizeof(void*) == 8) {
|
||||
/* 64-bit version, we use 64-bit numbers to avoid overflow when v1 == LEAN_MIN_SMALL_INT. */
|
||||
int64_t n = lean_scalar_to_int64(a1);
|
||||
int64_t d = lean_scalar_to_int64(a2);
|
||||
if (d == 0) {
|
||||
return a1;
|
||||
} else {
|
||||
int64_t r = n % d;
|
||||
if (r < 0)
|
||||
r = (d > 0) ? r + d : r - d;
|
||||
return lean_int64_to_int(r);
|
||||
}
|
||||
} else {
|
||||
/* 32-bit version */
|
||||
int n = lean_scalar_to_int(a1);
|
||||
int d = lean_scalar_to_int(a2);
|
||||
if (d == 0)
|
||||
return a1;
|
||||
else {
|
||||
int r = n % d;
|
||||
if (r < 0)
|
||||
r = (d > 0) ? r + d : r - d;
|
||||
return lean_int_to_int(r);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return lean_int_big_emod(a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool lean_int_eq(b_lean_obj_arg a1, b_lean_obj_arg a2) {
|
||||
if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
|
||||
return a1 == a2;
|
||||
|
|
|
|||
|
|
@ -160,6 +160,43 @@ mpz & mpz::operator*=(unsigned u) { mpz_mul_ui(m_val, m_val, u); return *this; }
|
|||
|
||||
mpz & mpz::operator*=(int u) { mpz_mul_si(m_val, m_val, u); return *this; }
|
||||
|
||||
mpz mpz::ediv(mpz const & n, mpz const & d) {
|
||||
mpz q;
|
||||
mpz_t r;
|
||||
mpz_init(r);
|
||||
/* (q,r) = (n/d, n%d) */
|
||||
mpz_tdiv_qr(q.m_val, r, n.m_val, d.m_val);
|
||||
/* if (r < 0) */
|
||||
if (mpz_sgn(r) < 0) {
|
||||
if (mpz_sgn(d.m_val) > 0) {
|
||||
/* q = q - 1. */
|
||||
mpz_sub_ui(q.m_val, q.m_val, 1);
|
||||
} else {
|
||||
/* q = q + 1. */
|
||||
mpz_add_ui(q.m_val, q.m_val, 1);
|
||||
}
|
||||
}
|
||||
mpz_clear(r);
|
||||
return q;
|
||||
}
|
||||
|
||||
mpz mpz::emod(mpz const & n, mpz const & d) {
|
||||
mpz r;
|
||||
/* (q,r) = (n/d, n%d) */
|
||||
mpz_tdiv_r(r.m_val, n.m_val, d.m_val);
|
||||
/* if (r < 0) */
|
||||
if (mpz_sgn(r.m_val) < 0) {
|
||||
if (mpz_sgn(d.m_val) > 0) {
|
||||
/* r = r + d. */
|
||||
mpz_add(r.m_val, r.m_val, d.m_val);
|
||||
} else {
|
||||
/* r = r - d. */
|
||||
mpz_sub(r.m_val, r.m_val, d.m_val);
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
mpz & mpz::operator/=(mpz const & o) { mpz_tdiv_q(m_val, m_val, o.m_val); return *this; }
|
||||
mpz & mpz::operator/=(unsigned u) { mpz_tdiv_q_ui(m_val, m_val, u); return *this; }
|
||||
|
||||
|
|
@ -630,6 +667,7 @@ mpz & mpz::rem(size_t sz, mpn_digit const * digits) {
|
|||
digits, sz,
|
||||
q1.begin(), r1.begin());
|
||||
set(r_sz, r1.begin());
|
||||
m_sign = m_sign && !is_zero();
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
@ -699,6 +737,53 @@ mpz & mpz::operator%=(mpz const & o) {
|
|||
return rem(o.m_size, o.m_digits);
|
||||
}
|
||||
|
||||
mpz mpz::ediv(mpz const & n, mpz const & d) {
|
||||
if (d.m_size > n.m_size) {
|
||||
if (n.is_neg()) {
|
||||
int64_t r = d.is_pos() ? -1 : 1;
|
||||
return mpz(r);
|
||||
} else {
|
||||
return mpz(0);
|
||||
}
|
||||
} else {
|
||||
digit_buffer q1, r1;
|
||||
size_t q_sz = n.m_size - d.m_size + 1;
|
||||
size_t r_sz = d.m_size;
|
||||
q1.ensure_capacity(q_sz);
|
||||
r1.ensure_capacity(r_sz);
|
||||
mpn_div(n.m_digits, n.m_size,
|
||||
d.m_digits, d.m_size,
|
||||
q1.begin(), r1.begin());
|
||||
mpz q;
|
||||
q.set(q_sz, q1.begin());
|
||||
q.m_sign = !q.is_zero() && n.m_sign != d.m_sign;
|
||||
mpz r;
|
||||
r.set(r_sz, r1.begin());
|
||||
r.m_sign = n.m_sign && !r.is_zero();
|
||||
if (r.is_neg()) {
|
||||
if (d.is_pos()) {
|
||||
q -= 1;
|
||||
} else {
|
||||
q += 1;
|
||||
}
|
||||
}
|
||||
return q;
|
||||
}
|
||||
}
|
||||
|
||||
mpz mpz::emod(mpz const & n, mpz const & d) {
|
||||
mpz r(n);
|
||||
r.rem(d.m_size, d.m_digits);
|
||||
if (r.is_neg()) {
|
||||
if (d.is_pos()) {
|
||||
r += d;
|
||||
} else {
|
||||
r -= d;
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
mpz mpz::pow(unsigned int p) const {
|
||||
unsigned mask = 1;
|
||||
mpz power(*this);
|
||||
|
|
|
|||
|
|
@ -245,6 +245,14 @@ public:
|
|||
|
||||
friend mpz operator%(mpz a, mpz const & b) { return a %= b; }
|
||||
|
||||
static mpz ediv(mpz const & n, mpz const & d);
|
||||
static mpz ediv(int n, mpz const & d) { return ediv(mpz(n), d); }
|
||||
static mpz ediv(mpz const& n, int d) { return ediv(n, mpz(d)); }
|
||||
|
||||
static mpz emod(mpz const & n, mpz const & d);
|
||||
static mpz emod(int n, mpz const & d) { return emod(mpz(n), d); }
|
||||
static mpz emod(mpz const & n, int d) { return emod(n, mpz(d)); };
|
||||
|
||||
mpz & operator&=(mpz const & o);
|
||||
mpz & operator|=(mpz const & o);
|
||||
mpz & operator^=(mpz const & o);
|
||||
|
|
|
|||
|
|
@ -1432,6 +1432,36 @@ extern "C" LEAN_EXPORT object * lean_int_big_mod(object * a1, object * a2) {
|
|||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_int_big_ediv(object * a1, object * a2) {
|
||||
if (lean_is_scalar(a1)) {
|
||||
return mpz_to_int(mpz::ediv(lean_scalar_to_int(a1), mpz_value(a2)));
|
||||
} else if (lean_is_scalar(a2)) {
|
||||
int d = lean_scalar_to_int(a2);
|
||||
if (d == 0)
|
||||
return a2;
|
||||
else
|
||||
return mpz_to_int(mpz::ediv(mpz_value(a1), d));
|
||||
} else {
|
||||
return mpz_to_int(mpz::ediv(mpz_value(a1), mpz_value(a2)));
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_int_big_emod(object * a1, object * a2) {
|
||||
if (lean_is_scalar(a1)) {
|
||||
return mpz_to_int(mpz::emod(lean_scalar_to_int(a1), mpz_value(a2)));
|
||||
} else if (lean_is_scalar(a2)) {
|
||||
int i2 = lean_scalar_to_int(a2);
|
||||
if (i2 == 0) {
|
||||
lean_inc(a1);
|
||||
return a1;
|
||||
} else {
|
||||
return mpz_to_int(mpz::emod(mpz_value(a1), i2));
|
||||
}
|
||||
} else {
|
||||
return mpz_to_int(mpz::emod(mpz_value(a1), mpz_value(a2)));
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT bool lean_int_big_eq(object * a1, object * a2) {
|
||||
if (lean_is_scalar(a1)) {
|
||||
lean_assert(lean_scalar_to_int(a1) != mpz_value(a2))
|
||||
|
|
|
|||
62
tests/lean/int_div_mod.lean
Normal file
62
tests/lean/int_div_mod.lean
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
-- Divide by zero tests
|
||||
#guard ( 0 : Int) / 0 = 0
|
||||
#guard ( 0 : Int) % 0 == 0
|
||||
#guard ( 4 : Int) / 0 == 0
|
||||
#guard ( 4 : Int) % 0 == 4
|
||||
#guard (-4 : Int) / 0 == 0
|
||||
#guard (-4 : Int) % 0 == -4
|
||||
#guard ( 0 : Int) / 4 == 0
|
||||
#guard ( 0 : Int) % 4 == 0
|
||||
#guard ( 0 : Int) / -4 == 0
|
||||
#guard ( 0 : Int) % -4 == 0
|
||||
|
||||
-- Euclidean division tests
|
||||
#guard ( 4 : Int) / 3 == 1
|
||||
#guard ( 4 : Int) % 3 == 1
|
||||
#guard ( 5 : Int) / 3 == 1
|
||||
#guard ( 5 : Int) % 3 == 2
|
||||
#guard ( 6 : Int) / 3 == 2
|
||||
#guard ( 6 : Int) % 3 == 0
|
||||
#guard ( 7 : Int) / 4 == 1
|
||||
#guard ( 7 : Int) % 4 == 3
|
||||
|
||||
#guard ( 4 : Int) / -3 == -1
|
||||
#guard ( 4 : Int) % -3 == 1
|
||||
#guard ( 5 : Int) / -3 == -1
|
||||
#guard ( 5 : Int) % -3 == 2
|
||||
#guard ( 6 : Int) / -3 == -2
|
||||
#guard ( 6 : Int) % -3 == 0
|
||||
#guard ( 7 : Int) / -4 == -1
|
||||
#guard ( 7 : Int) % -4 == 3
|
||||
|
||||
#guard (-4 : Int) / 3 == -2
|
||||
#guard (-4 : Int) % 3 == 2
|
||||
#guard (-5 : Int) / 3 == -2
|
||||
#guard (-5 : Int) % 3 == 1
|
||||
#guard (-6 : Int) / 3 == -2
|
||||
#guard (-6 : Int) % 3 == 0
|
||||
#guard (-7 : Int) / 4 == -2
|
||||
#guard (-7 : Int) % 4 == 1
|
||||
|
||||
#guard (-4 : Int) / -3 == 2
|
||||
#guard (-4 : Int) % -3 == 2
|
||||
#guard (-5 : Int) / -3 == 2
|
||||
#guard (-5 : Int) % -3 == 1
|
||||
#guard (-6 : Int) / -3 == 2
|
||||
#guard (-6 : Int) % -3 == 0
|
||||
#guard (-7 : Int) / -4 == 2
|
||||
#guard (-7 : Int) % -4 == 1
|
||||
|
||||
-- Basic big integer tests
|
||||
#guard let n : Int := 0; let d : Int := 2^64; n / d = 0 ∧ n % d = n
|
||||
#guard let n : Int := 1; let d : Int := 2^64; n / d = 0 ∧ n % d = n
|
||||
#guard let n : Int := -1; let d : Int := 2^64; n / d = -1 ∧ n % d = (d + n)
|
||||
#guard let n : Int := 2^128; let d : Int := 3; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
|
||||
#guard let n : Int := 2^128; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
|
||||
#guard let n : Int := -2^128; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
|
||||
#guard let n : Int := 2^128; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
|
||||
#guard let n : Int := -2^128; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
|
||||
#guard let n : Int := 2^128+7; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
|
||||
#guard let n : Int := -2^128+3; let d : Int := 2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d
|
||||
#guard let n : Int := 2^128+2; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
|
||||
#guard let n : Int := -2^128+2; let d : Int := -2^64; d * (n / d) + n % d = n ∧ n % d ≥ 0 ∧ n % d < d.natAbs
|
||||
0
tests/lean/int_div_mod.lean.expected.out
Normal file
0
tests/lean/int_div_mod.lean.expected.out
Normal file
Loading…
Add table
Reference in a new issue