diff --git a/src/util/mpbq.cpp b/src/util/mpbq.cpp index c628eb9451..f039e12d46 100644 --- a/src/util/mpbq.cpp +++ b/src/util/mpbq.cpp @@ -15,41 +15,196 @@ void mpbq::normalize() { m_k = 0; return; } - unsigned k = m_num.power_of_two_multiple(); if (k > m_k) k = m_k; - m_num.div2k(k); + div2k(m_num, m_num, k); m_k -= k; } - int cmp(mpbq const & a, mpbq const & b) { + static thread_local mpz tmp; if (a.m_k == b.m_k) return cmp(a.m_num, b.m_num); else if (a.m_k < b.m_k) { - mpz tmp(a.m_num); - tmp.mul2k(b.m_k - a.m_k); + mul2k(tmp, a.m_num, b.m_k - a.m_k); return cmp(tmp, b.m_num); } else { lean_assert(a.m_k > b.m_k); - mpz tmp(b.m_num); - tmp.mul2k(a.m_k - b.m_k); + mul2k(tmp, b.m_num, a.m_k - b.m_k); return cmp(a.m_num, tmp); } } int cmp(mpbq const & a, mpz const & b) { + static thread_local mpz tmp; if (a.m_k == 0) return cmp(a.m_num, b); else { - mpz tmp(b); - tmp.mul2k(a.m_k); + mul2k(tmp, b, a.m_k); return cmp(a.m_num, tmp); } } +mpbq & mpbq::operator+=(mpbq const & a) { + if (m_k == a.m_k) { + m_num += a.m_num; + } + else if (m_k < a.m_k) { + mul2k(m_num, m_num, a.m_k - m_k); + m_k = a.m_k; + m_num += a.m_num; + } + else { + lean_assert(m_k > a.m_k); + static thread_local mpz tmp; + mul2k(tmp, a.m_num, m_k - a.m_k); + m_num += tmp; + } + normalize(); + return *this; +} + +template +mpbq & mpbq::add_int(T const & a) { + if (m_k == 0) { + m_num += a; + } + else { + lean_assert(m_k > 0); + static thread_local mpz tmp; + tmp = a; + mul2k(tmp, tmp, m_k); + m_num += tmp; + } + normalize(); + return *this; +} +mpbq & mpbq::operator+=(unsigned a) { return add_int(a); } +mpbq & mpbq::operator+=(int a) { return add_int(a); } + +mpbq & mpbq::operator-=(mpbq const & a) { + if (m_k == a.m_k) { + m_num -= a.m_num; + } + else if (m_k < a.m_k) { + mul2k(m_num, m_num, a.m_k - m_k); + m_k = a.m_k; + m_num -= a.m_num; + } + else { + lean_assert(m_k > a.m_k); + static thread_local mpz tmp; + mul2k(tmp, a.m_num, m_k - a.m_k); + m_num -= tmp; + } + normalize(); + return *this; +} + +template +mpbq & mpbq::sub_int(T const & a) { + if (m_k == 0) { + m_num -= a; + } + else { + lean_assert(m_k > 0); + static thread_local mpz tmp; + tmp = a; + mul2k(tmp, tmp, m_k); + m_num -= tmp; + } + normalize(); + return *this; +} +mpbq & mpbq::operator-=(unsigned a) { return sub_int(a); } +mpbq & mpbq::operator-=(int a) { return sub_int(a); } + +mpbq & mpbq::operator*=(mpbq const & a) { + m_num *= a.m_num; + if (m_k == 0 || a.m_k == 0) { + m_k += a.m_k; + normalize(); + } + else { + m_k += a.m_k; + } + return *this; +} + +template +mpbq & mpbq::mul_int(T const & a) { + m_num *= a; + normalize(); + return *this; +} +mpbq & mpbq::operator*=(unsigned a) { return mul_int(a); } +mpbq & mpbq::operator*=(int a) { return mul_int(a); } + +int mpbq::magnitude_lb() const { + int s = m_num.sgn(); + if (s < 0) { + return m_num.mlog2() - m_k + 1; + } + else if (s == 0) { + return 0; + } + else { + lean_assert(s > 0); + return m_num.log2() - m_k; + } +} + +int mpbq::magnitude_ub() const { + int s = m_num.sgn(); + if (s < 0) { + return m_num.mlog2() - m_k; + } + else if (s == 0) { + return 0; + } + else { + lean_assert(s > 0); + return m_num.log2() - m_k + 1; + } +} + +void mul2(mpbq & a) { + if (a.m_k == 0) { + mul2k(a.m_num, a.m_num, 1); + } + else { + a.m_k--; + } +} + +void mul2k(mpbq & a, unsigned k) { + if (k == 0) + return; + if (a.m_k < k) { + mul2k(a.m_num, a.m_num, k - a.m_k); + a.m_k = 0; + } + else { + lean_assert(a.m_k >= k); + a.m_k -= k; + } +} + +std::ostream & operator<<(std::ostream & out, mpbq const & v) { + if (v.m_k == 0) { + out << v.m_num; + } + else if (v.m_k == 1) { + out << v.m_num << "/2"; + } + else { + out << v.m_num << "/2^" << v.m_k; + } + return out; +} } +void pp(lean::mpbq const & n) { std::cout << n << std::endl; } diff --git a/src/util/mpbq.h b/src/util/mpbq.h index 439e4f9289..c7441d2ff2 100644 --- a/src/util/mpbq.h +++ b/src/util/mpbq.h @@ -14,6 +14,9 @@ class mpbq { mpz m_num; unsigned m_k; void normalize(); + template mpbq & add_int(T const & a); + template mpbq & sub_int(T const & a); + template mpbq & mul_int(T const & a); public: mpbq():m_k(0) {} mpbq(mpbq const & v):m_num(v.m_num), m_k(v.m_k) {} @@ -23,6 +26,11 @@ public: mpbq(int n, unsigned k):m_num(n), m_k(k) { normalize(); } ~mpbq() {} + mpbq & operator=(mpbq const & v) { m_num = v.m_num; m_k = v.m_k; return *this; } + mpbq & operator=(mpbq && v) { swap(v); return *this; } + mpbq & operator=(unsigned int v) { m_num = v; m_k = 0; return *this; } + mpbq & operator=(int v) { m_num = v; m_k = 0; return *this; } + void swap(mpbq & o) { m_num.swap(o.m_num); std::swap(m_k, o.m_k); } unsigned hash() const { return m_num.hash(); } @@ -139,6 +147,43 @@ public: mpbq & operator--() { return operator-=(1); } mpbq operator--(int) { mpbq r(*this); --(*this); return r; } + + /** + \brief Return the magnitude of a = b/2^k. + It is defined as: + a == 0 -> 0 + a > 0 -> log2(b) - k Note that 2^{log2(b) - k} <= a <= 2^{log2(b) - k + 1} + a < 0 -> mlog2(b) - k + 1 Note that -2^{mlog2(b) - k + 1} <= a <= -2^{mlog2(b) - k} + + Remark: mlog2(b) = log2(-b) + + Examples: + + 5/2^3 log2(5) - 3 = -1 + 21/2^2 log2(21) - 2 = 2 + -3/2^4 log2(3) - 4 + 1 = -2 + */ + int magnitude_lb() const; + + /** + \brief Similar to magnitude_lb + + a == 0 -> 0 + a > 0 -> log2(b) - k + 1 a <= 2^{log2(b) - k + 1} + a < 0 -> mlog2(b) - k a <= -2^{mlog2(b) - k} + */ + int magnitude_ub() const; + + // a <- a*2 + friend void mul2(mpbq & a); + // a <- a*2^k + friend void mul2k(mpbq & a, unsigned k); + + // a <- b * 2^k + friend void mul2k(mpbq & a, mpbq const & b, unsigned k) { a = b; mul2k(a, k); } + // a <- b / 2^k + friend void div2k(mpbq & a, mpbq const & b, unsigned k); + friend std::ostream & operator<<(std::ostream & out, mpbq const & v); }; diff --git a/src/util/mpq.cpp b/src/util/mpq.cpp index c52496c4f9..5180c609b1 100644 --- a/src/util/mpq.cpp +++ b/src/util/mpq.cpp @@ -62,6 +62,7 @@ std::ostream & operator<<(std::ostream & out, mpq const & v) { return out; } -void pp(mpq const & v) { std::cout << v << std::endl; } - } + +void pp(lean::mpq const & v) { std::cout << v << std::endl; } + diff --git a/src/util/mpq.h b/src/util/mpq.h index a4d317f4f9..041e38e9ac 100644 --- a/src/util/mpq.h +++ b/src/util/mpq.h @@ -15,8 +15,13 @@ class mpq { static mpz_t const & zval(mpz const & v) { return v.m_val; } static mpz_t & zval(mpz & v) { return v.m_val; } public: + void swap(mpq & v) { mpq_swap(m_val, v.m_val); } + void swap_numerator(mpz & v) { mpz_swap(mpq_numref(m_val), v.m_val); mpq_canonicalize(m_val); } + void swap_denominator(mpz & v) { mpz_swap(mpq_denref(m_val), v.m_val); mpq_canonicalize(m_val); } + mpq & operator=(mpz const & v) { mpq_set_z(m_val, v.m_val); return *this; } mpq & operator=(mpq const & v) { mpq_set(m_val, v.m_val); return *this; } + mpq & operator=(mpq && v) { swap(v); return *this; } mpq & operator=(char const * v) { mpq_set_str(m_val, v, 10); return *this; } mpq & operator=(unsigned long int v) { mpq_set_ui(m_val, v, 1u); return *this; } mpq & operator=(long int v) { mpq_set_si(m_val, v, 1); return *this; } @@ -39,10 +44,6 @@ public: mpq(double v):mpq() { mpq_set_d(m_val, v); } ~mpq() { mpq_clear(m_val); } - void swap(mpq & v) { mpq_swap(m_val, v.m_val); } - void swap_numerator(mpz & v) { mpz_swap(mpq_numref(m_val), v.m_val); mpq_canonicalize(m_val); } - void swap_denominator(mpz & v) { mpz_swap(mpq_denref(m_val), v.m_val); mpq_canonicalize(m_val); } - unsigned hash() const { return static_cast(mpz_get_si(mpq_numref(m_val))); } int sgn() const { return mpq_sgn(m_val); } diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index 961222e078..98c9df14b3 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -72,3 +72,5 @@ std::ostream & operator<<(std::ostream & out, mpz const & v) { } } + +void pp(lean::mpz const & n) { std::cout << n << std::endl; } diff --git a/src/util/mpz.h b/src/util/mpz.h index e0ad7212ad..f1bf933352 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -59,6 +59,14 @@ public: unsigned long int get_unsigned_long_int() const { lean_assert(is_unsigned_long_int()); return mpz_get_ui(m_val); } unsigned int get_unsigned_int() const { lean_assert(is_unsigned_int()); return static_cast(get_unsigned_long_int()); } + mpz & operator=(mpz const & v) { mpz_set(m_val, v.m_val); return *this; } + mpz & operator=(mpz && v) { swap(v); return *this; } + mpz & operator=(char const * v) { mpz_set_str(m_val, v, 10); return *this; } + mpz & operator=(unsigned long int v) { mpz_set_ui(m_val, v); return *this; } + mpz & operator=(long int v) { mpz_set_si(m_val, v); return *this; } + mpz & operator=(unsigned int v) { return operator=(static_cast(v)); } + mpz & operator=(int v) { return operator=(static_cast(v)); } + friend int cmp(mpz const & a, mpz const & b) { return mpz_cmp(a.m_val, b.m_val); } friend int cmp(mpz const & a, unsigned b) { return mpz_cmp_ui(a.m_val, b); } friend int cmp(mpz const & a, int b) { return mpz_cmp_si(a.m_val, b); } @@ -164,11 +172,11 @@ public: // this <- this - a*b void submul(mpz const & a, mpz const & b) { mpz_submul(m_val, a.m_val, b.m_val); } - // this <- this * 2^k - void mul2k(unsigned k) { mpz_mul_2exp(m_val, m_val, k); } - // this <- this / 2^k - void div2k(unsigned k) { mpz_tdiv_q_2exp(m_val, m_val, k); } - + // a <- b * 2^k + friend void mul2k(mpz & a, mpz const & b, unsigned k) { mpz_mul_2exp(a.m_val, b.m_val, k); } + // a <- b / 2^k + friend void div2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_q_2exp(a.m_val, b.m_val, k); } + /** \brief Return the position of the most significant bit. Return 0 if the number is negative