diff --git a/src/runtime/mpz.cpp b/src/runtime/mpz.cpp index 2132d306c0..4b671aeeeb 100644 --- a/src/runtime/mpz.cpp +++ b/src/runtime/mpz.cpp @@ -239,13 +239,13 @@ std::ostream & operator<<(std::ostream & out, mpz const & v) { /***** NON GMP VERSION ******/ void mpz::allocate(size_t s) { - m_sign = false; m_size = s; m_digits = static_cast(alloc(s * sizeof(mpn_digit))); } void mpz::init() { allocate(1); + m_sign = false; m_digits[0] = 0; } @@ -269,15 +269,17 @@ void mpz::init_str(char const * v) { void mpz::init_uint(unsigned int v) { allocate(1); + m_sign = false; m_digits[0] = v; } void mpz::init_int(int v) { allocate(1); if (v < 0) { - m_sign = true; + m_sign = true; m_digits[0] = -v; } else { + m_sign = false; m_digits[0] = v; } } @@ -285,9 +287,11 @@ void mpz::init_int(int v) { void mpz::init_uint64(uint64 v) { if (v <= std::numeric_limits::max()) { allocate(1); + m_sign = false; m_digits[0] = v; } else { allocate(2); + m_sign = false; (reinterpret_cast(m_digits))[0] = v; } } @@ -442,38 +446,72 @@ int cmp(mpz const & a, unsigned b) { } int cmp(mpz const & a, int b) { + if (a.m_sign) { + if (b < 0) { + unsigned b1 = -b; + return mpn_compare(&b1, 1, a.m_digits, a.m_size); + } else { + return -1; + } + } else { + if (b < 0) { + return 1; + } else { + unsigned b1 = b; + return mpn_compare(a.m_digits, a.m_size, &b1, 1); + } + } +} + +void mpz::set(size_t sz, mpn_digit const * digits) { + while (sz > 1 && digits[sz - 1] == 0) + sz--; + if (sz != m_size) { + dealloc(m_digits, sizeof(mpn_digit)*m_size); + allocate(sz); + } + memcpy(m_digits, digits, sizeof(mpn_digit)*sz); +} + +mpz & mpz::add(bool sign, size_t sz, mpn_digit const * digits) { // TODO lean_unreachable(); } mpz & mpz::operator+=(mpz const & o) { - // TODO - lean_unreachable(); + return add(o.m_sign, o.m_size, o.m_digits); } mpz & mpz::operator+=(unsigned u) { - // TODO - lean_unreachable(); + return add(false, 1, &u); } mpz & mpz::operator+=(int u) { - // TODO - lean_unreachable(); + if (u < 0) { + unsigned u1 = -u; + return add(true, 1, &u1); + } else { + unsigned u1 = u; + return add(false, 1, &u1); + } } mpz & mpz::operator-=(mpz const & o) { - // TODO - lean_unreachable(); + return add(!o.m_sign, o.m_size, o.m_digits); } mpz & mpz::operator-=(unsigned u) { - // TODO - lean_unreachable(); + return add(true, 1, &u); } mpz & mpz::operator-=(int u) { - // TODO - lean_unreachable(); + if (u < 0) { + unsigned u1 = -u; + return add(false, 1, &u1); + } else { + unsigned u1 = u; + return add(true, 1, &u1); + } } mpz & mpz::operator*=(mpz const & o) { @@ -506,9 +544,17 @@ mpz rem(mpz const & a, mpz const & b) { lean_unreachable(); } -mpz mpz::pow(unsigned int exp) const { - // TODO - lean_unreachable(); +mpz mpz::pow(unsigned int p) const { + unsigned mask = 1; + mpz power(p); + mpz result(1); + while (mask <= p) { + if (mask & p) + result *= power; + power *= power; + mask = mask << 1; + } + return result; } size_t mpz::log2() const { @@ -552,8 +598,8 @@ void mod2k(mpz & a, mpz const & b, unsigned k) { } void power(mpz & a, mpz const & b, unsigned k) { - // TODO - lean_unreachable(); + a = b; + a.pow(k); } void gcd(mpz & g, mpz const & a, mpz const & b) { diff --git a/src/runtime/mpz.h b/src/runtime/mpz.h index d88ea6f63d..f3f40e3228 100644 --- a/src/runtime/mpz.h +++ b/src/runtime/mpz.h @@ -39,6 +39,8 @@ class mpz { void init_uint64(uint64 v); void init_int64(int64 v); void init_mpz(mpz const & v); + void set(size_t sz, mpn_digit const * digits); + mpz & add(bool sign, size_t sz, mpn_digit const * digits); #endif public: mpz();