lean4-htt/src/runtime/mpz.cpp
2021-11-29 16:01:07 -08:00

578 lines
12 KiB
C++

/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <memory>
#include <string>
#include <cstring>
#include "runtime/sstream.h"
#include "runtime/alloc.h"
#include "runtime/thread.h"
#include "runtime/mpz.h"
#include "runtime/debug.h"
namespace lean {
/***** GMP VERSION ******/
#ifdef LEAN_USE_GMP
// Constructs the integer 0.
mpz::mpz() {
    mpz_init(m_val);
}
// Constructs a copy of the raw GMP integer `v`; `v` itself is left unchanged.
mpz::mpz(mpz_t v) {
    mpz_init_set(m_val, v);
}
// Parses `v` as a base-10 integer.
// NOTE(review): the status code of mpz_init_set_str is ignored, so a malformed
// string is not reported to the caller.
mpz::mpz(char const * v) {
    mpz_init_set_str(m_val, v, 10);
}
// Constructs the value of the unsigned machine integer `v`.
mpz::mpz(unsigned int v) {
    mpz_init(m_val);
    mpz_set_ui(m_val, v);
}
// Constructs the value of the signed machine integer `v`.
mpz::mpz(int v) {
    mpz_init(m_val);
    mpz_set_si(m_val, v);
}
// Constructs the value of the unsigned 64-bit integer `v` as
// low_half + (high_half << 32). The delegated constructor installs the low half.
// NOTE(review): this split assumes `unsigned` is exactly 32 bits wide -- TODO confirm.
mpz::mpz(uint64 v):
    mpz(static_cast<unsigned>(v)) {
    unsigned const high_half = static_cast<unsigned>(v >> 32);
    mpz shifted_high(high_half);
    mpz_mul_2exp(shifted_high.m_val, shifted_high.m_val, 32);
    mpz_add(m_val, m_val, shifted_high.m_val);
}
// Constructs the value of the signed 64-bit integer `v`.
// The magnitude |v| is computed in unsigned arithmetic and handed to the
// uint64 constructor. Writing `-v` on the signed value overflows for INT64_MIN,
// which is undefined behavior. The sign is applied afterwards with mpz_neg.
mpz::mpz(int64 v):
    mpz(v < 0 ? uint64(0) - static_cast<uint64>(v) : static_cast<uint64>(v)) {
    if (v < 0)
        mpz_neg(m_val, m_val);
}
// Copy constructor: initializes this value as an independent copy of `s`.
mpz::mpz(mpz const & s) {
    mpz_init(m_val);
    mpz_set(m_val, s.m_val);
}
// Move constructor: starts from a freshly initialized 0 and swaps it with `s`.
// `s` is therefore left holding 0, which its own destructor can still clear.
mpz::mpz(mpz && s) {
    mpz_init(m_val);
    mpz_swap(m_val, s.m_val);
}
// Releases the storage owned by the underlying GMP integer.
mpz::~mpz() {
    mpz_clear(m_val);
}
// Copies this value into the caller-supplied GMP integer `r`.
// GMP requires `r` to be initialized already.
void mpz::set(mpz_t r) const { mpz_set(r, m_val); }
// Exchanges the values held by `a` and `b`.
void swap(mpz & a, mpz & b) { mpz_swap(a.m_val, b.m_val); }
int mpz::sgn() const {
return mpz_sgn(m_val);
}
bool mpz::is_int() const {
return mpz_fits_sint_p(m_val) != 0;
}
bool mpz::is_unsigned_int() const {
return mpz_fits_uint_p(m_val) != 0;
}
bool mpz::is_size_t() const {
// GMP only features `fits` functions up to `unsigned long`, which is smaller than `size_t` on Windows.
// So we directly count the number of mpz words instead.
static_assert(sizeof(size_t) == sizeof(mp_limb_t), "GMP word size should be equal to system word size");
return is_nonneg() && mpz_size(m_val) <= 1;
}
int mpz::get_int() const {
lean_assert(is_int());
return static_cast<int>(mpz_get_si(m_val));
}
unsigned int mpz::get_unsigned_int() const {
lean_assert(is_unsigned_int());
return static_cast<unsigned>(mpz_get_ui(m_val));
}
size_t mpz::get_size_t() const {
// GMP only features accessors up to `unsigned long`, which is smaller than `size_t` on Windows.
// So we directly access the lowest mpz word instead.
static_assert(sizeof(size_t) == sizeof(mp_limb_t), "GMP word size should be equal system word size");
// NOTE: mpz_getlimbn returns 0 if the index is out of range (i.e. `m_val == 0`)
return static_cast<size_t>(mpz_getlimbn(m_val, 0));
}
// Assignment operators: each overwrites the current GMP value in place.
mpz & mpz::operator=(mpz const & v) {
    mpz_set(m_val, v.m_val);
    return *this;
}
// NOTE(review): the status code of mpz_set_str is ignored, so a malformed
// string is not reported to the caller.
mpz & mpz::operator=(char const * v) {
    mpz_set_str(m_val, v, 10);
    return *this;
}
mpz & mpz::operator=(unsigned int v) {
    mpz_set_ui(m_val, v);
    return *this;
}
mpz & mpz::operator=(int v) {
    mpz_set_si(m_val, v);
    return *this;
}
// Three-way comparisons: the result is negative, zero or positive when `a` is
// respectively less than, equal to, or greater than `b`.
int cmp(mpz const & a, mpz const & b) {
    int const order = mpz_cmp(a.m_val, b.m_val);
    return order;
}
int cmp(mpz const & a, unsigned b) {
    int const order = mpz_cmp_ui(a.m_val, b);
    return order;
}
int cmp(mpz const & a, int b) {
    int const order = mpz_cmp_si(a.m_val, b);
    return order;
}
// In-place arithmetic. Division (`/=`) and `rem` truncate toward zero.
mpz & mpz::operator+=(mpz const & o) { mpz_add(m_val, m_val, o.m_val); return *this; }
mpz & mpz::operator+=(unsigned u) { mpz_add_ui(m_val, m_val, u); return *this; }
// For negative `u` the magnitude is formed in unsigned arithmetic (0u - u).
// Writing `-u` on the int overflows for INT_MIN, which is undefined behavior.
mpz & mpz::operator+=(int u) {
    if (u >= 0)
        mpz_add_ui(m_val, m_val, static_cast<unsigned>(u));
    else
        mpz_sub_ui(m_val, m_val, 0u - static_cast<unsigned>(u));
    return *this;
}
mpz & mpz::operator-=(mpz const & o) { mpz_sub(m_val, m_val, o.m_val); return *this; }
mpz & mpz::operator-=(unsigned u) { mpz_sub_ui(m_val, m_val, u); return *this; }
// Same INT_MIN-safe magnitude computation as operator+=(int).
mpz & mpz::operator-=(int u) {
    if (u >= 0)
        mpz_sub_ui(m_val, m_val, static_cast<unsigned>(u));
    else
        mpz_add_ui(m_val, m_val, 0u - static_cast<unsigned>(u));
    return *this;
}
mpz & mpz::operator*=(mpz const & o) { mpz_mul(m_val, m_val, o.m_val); return *this; }
mpz & mpz::operator*=(unsigned u) { mpz_mul_ui(m_val, m_val, u); return *this; }
mpz & mpz::operator*=(int u) { mpz_mul_si(m_val, m_val, u); return *this; }
mpz & mpz::operator/=(mpz const & o) { mpz_tdiv_q(m_val, m_val, o.m_val); return *this; }
mpz & mpz::operator/=(unsigned u) { mpz_tdiv_q_ui(m_val, m_val, u); return *this; }
// Remainder of truncating division; the result has the sign of `a`.
mpz rem(mpz const & a, mpz const & b) { mpz r; mpz_tdiv_r(r.m_val, a.m_val, b.m_val); return r; }
// Returns this value raised to the power `exp`.
mpz mpz::pow(unsigned int exp) const {
    mpz result;
    mpz_pow_ui(result.m_val, m_val, exp);
    return result;
}
// Returns floor(log2(*this)) for positive values, and 0 for zero or negative ones.
// mpz_sizeinbase is exact in base 2, so bit length minus one is the index of
// the highest set bit.
size_t mpz::log2() const {
    if (is_nonpos())
        return 0;
    size_t const bit_length = mpz_sizeinbase(m_val, 2);
    lean_assert(bit_length > 0);
    return bit_length - 1;
}
// Remainder operator; forwards to rem(a, b).
mpz operator%(mpz const & a, mpz const & b) { return rem(a, b); }
// In-place bitwise AND / OR / XOR. GMP gives negative operands two's-complement
// semantics.
mpz & mpz::operator&=(mpz const & o) { mpz_and(m_val, m_val, o.m_val); return *this; }
mpz & mpz::operator|=(mpz const & o) { mpz_ior(m_val, m_val, o.m_val); return *this; }
mpz & mpz::operator^=(mpz const & o) { mpz_xor(m_val, m_val, o.m_val); return *this; }
// a := b * 2^k
void mul2k(mpz & a, mpz const & b, unsigned k) { mpz_mul_2exp(a.m_val, b.m_val, k); }
// a := b / 2^k, truncated toward zero
void div2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_q_2exp(a.m_val, b.m_val, k); }
// a := remainder of b / 2^k under truncating division
void mod2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_r_2exp(a.m_val, b.m_val, k); }
// a := b^k
void power(mpz & a, mpz const & b, unsigned k) { mpz_pow_ui(a.m_val, b.m_val, k); }
// g := greatest common divisor of a and b
void gcd(mpz & g, mpz const & a, mpz const & b) { mpz_gcd(g.m_val, a.m_val, b.m_val); }
// Writes the base-10 representation of the GMP integer `v` to `out`.
// mpz_get_str needs mpz_sizeinbase(v, 10) + 2 bytes: one for a possible minus
// sign and one for the terminating NUL. Results that fit go into a stack
// buffer; larger ones use a heap buffer.
void display(std::ostream & out, __mpz_struct const * v) {
    size_t sz = mpz_sizeinbase(v, 10) + 2;
    if (sz < 1024) {
        char buffer[1024];
        mpz_get_str(buffer, 10, v);
        out << buffer;
    } else {
        // The array form `unique_ptr<char[]>` is required so the buffer is
        // released with `delete[]`. `unique_ptr<char>` would apply scalar
        // `delete` to memory from `new[]`, which is undefined behavior.
        std::unique_ptr<char[]> buffer(new char[sz]);
        mpz_get_str(buffer.get(), 10, v);
        out << buffer.get();
    }
}
// Streams `v` in base 10.
std::ostream & operator<<(std::ostream & out, mpz const & v) {
    display(out, v.m_val);
    return out;
}
#else
/***** NON GMP VERSION ******/
// Obtains a buffer of `s` digits and marks the value non-negative.
// allocate itself does not write any digit; callers fill them in.
void mpz::allocate(size_t s) {
    m_sign = false;
    m_size = s;
    size_t const nbytes = s * sizeof(mpn_digit);
    m_digits = static_cast<mpn_digit*>(alloc(nbytes));
}
// Initializes this value to 0, stored as a single zero digit.
void mpz::init() {
    allocate(1);
    m_digits[0] = 0;
}
// Initializes this value from a decimal string.
// Leading spaces are skipped. A '-' as the first non-space character makes the
// result negative. Every other non-digit character is silently ignored.
// NOTE(review): this relies on operator*=(int) and operator+=(unsigned), which
// are still `lean_unreachable()` stubs in this non-GMP section.
void mpz::init_str(char const * v) {
    init();
    char const * p = v;
    while (*p == ' ')
        ++p;
    bool const negative = (*p == '-');
    for (; *p != '\0'; ++p) {
        char const c = *p;
        if (c >= '0' && c <= '9') {
            operator*=(10);
            operator+=(static_cast<unsigned>(c - '0'));
        }
    }
    if (negative)
        neg();
}
// Initializes this value from an unsigned machine integer (a single digit).
void mpz::init_uint(unsigned int v) {
    allocate(1);
    m_digits[0] = v;
}
// Initializes this value from a signed machine integer in sign-magnitude form.
void mpz::init_int(int v) {
    allocate(1);
    if (v < 0) {
        m_sign = true;
        // Negate in unsigned arithmetic. `-v` overflows for INT_MIN, which is
        // undefined behavior.
        m_digits[0] = 0u - static_cast<unsigned>(v);
    } else {
        m_digits[0] = static_cast<unsigned>(v);
    }
}
// Initializes this value from an unsigned 64-bit integer. It uses one digit
// when `v` fits in an `unsigned`, and two digits otherwise.
// The two-digit layout is least-significant digit first. That is what the
// previous `uint64` store produced on little-endian hosts.
// NOTE(review): confirm this matches the digit order the mpn_* routines expect.
void mpz::init_uint64(uint64 v) {
    static_assert(sizeof(mpn_digit) * 2 == sizeof(uint64), "init_uint64 assumes two mpn_digits hold a uint64");
    if (v <= std::numeric_limits<unsigned>::max()) {
        allocate(1);
        m_digits[0] = static_cast<mpn_digit>(v);
    } else {
        allocate(2);
        // Write each half explicitly. Storing through `reinterpret_cast<uint64*>`
        // broke strict aliasing, assumed 8-byte alignment of the digit buffer,
        // and gave a host-endianness-dependent digit order.
        m_digits[0] = static_cast<mpn_digit>(v);
        m_digits[1] = static_cast<mpn_digit>(v >> 32);
    }
}
// Initializes this value from a signed 64-bit integer.
// The magnitude is stored via init_uint64, which clears the sign through
// allocate(). The sign flag is then set from `v`.
void mpz::init_int64(int64 v) {
    // |v| is computed in unsigned arithmetic so INT64_MIN does not overflow.
    uint64 const magnitude = v < 0 ? uint64(0) - static_cast<uint64>(v) : static_cast<uint64>(v);
    init_uint64(magnitude);
    m_sign = v < 0;
}
// Initializes this value as a deep copy of `v`: sign, digit count and a freshly
// allocated copy of the digits.
void mpz::init_mpz(mpz const & v) {
    m_sign = v.m_sign;
    m_size = v.m_size;
    size_t const nbytes = m_size * sizeof(mpn_digit);
    m_digits = static_cast<mpn_digit*>(alloc(nbytes));
    memcpy(m_digits, v.m_digits, nbytes);
}
// Constructors: each forwards to the matching `init_*` helper.
mpz::mpz() { init(); }
mpz::mpz(char const * v) { init_str(v); }
mpz::mpz(unsigned int v) { init_uint(v); }
mpz::mpz(int v) { init_int(v); }
mpz::mpz(uint64 v) { init_uint64(v); }
mpz::mpz(int64 v) { init_int64(v); }
mpz::mpz(mpz const & s) { init_mpz(s); }
// Move constructor: takes over the digit buffer of `s`. `s.m_digits` is then
// null, so the destructor of `s` skips the release.
mpz::mpz(mpz && s):
    m_sign(s.m_sign),
    m_size(s.m_size),
    m_digits(s.m_digits) {
    s.m_digits = nullptr;
}
// Releases the digit buffer unless a move has taken it away.
mpz::~mpz() {
    if (m_digits != nullptr)
        dealloc(m_digits, m_size * sizeof(mpn_digit));
}
// Exchanges the complete state (digit buffer, digit count, sign) of `a` and
// `b`. No digits are copied.
void swap(mpz & a, mpz & b) {
    std::swap(a.m_digits, b.m_digits);
    std::swap(a.m_size, b.m_size);
    std::swap(a.m_sign, b.m_sign);
}
// Returns -1, 0 or +1 according to the sign of this value.
// A value with at most one digit whose first digit is 0 counts as zero,
// regardless of the sign flag. Multi-digit values are always treated as
// nonzero.
int mpz::sgn() const {
    bool const is_zero = m_size <= 1 && m_digits[0] == 0;
    if (is_zero)
        return 0;
    return m_sign ? -1 : 1;
}
// Not implemented in the non-GMP backend: any call reaches `lean_unreachable()`.
// NOTE(review): there is no `return` after `lean_unreachable()`. That is only
// well-defined if it never returns -- TODO confirm in runtime/debug.h.
bool mpz::is_int() const {
// TODO
lean_unreachable();
}
// True iff the value is non-negative and stored in exactly one digit.
bool mpz::is_unsigned_int() const {
    if (m_sign)
        return false;
    return m_size == 1;
}
// True iff the value is non-negative and its digit count fits a `size_t`:
// up to two digits on 64-bit `size_t` targets, exactly one digit otherwise.
bool mpz::is_size_t() const {
    if (m_sign)
        return false;
    if (sizeof(size_t) == 8)
        return m_size <= 2;
    return m_size == 1;
}
// Accessors not yet implemented in the non-GMP backend: each call reaches
// `lean_unreachable()`.
// NOTE(review): these non-void functions have no `return`. That is only
// well-defined if `lean_unreachable()` never returns -- TODO confirm.
int mpz::get_int() const {
// TODO
lean_unreachable();
}
unsigned int mpz::get_unsigned_int() const {
// TODO
lean_unreachable();
}
size_t mpz::get_size_t() const {
// TODO
lean_unreachable();
}
// Assignment operators.
// Each builds the new value in a temporary and then exchanges state with it.
// The temporary's destructor releases the previous digits. The old code freed
// `m_digits` up front; this version instead:
//   * makes self-assignment (`a = a`) safe. Before, the digits of `v` were
//     freed before init_mpz copied from them (use-after-free);
//   * never hands a null `m_digits` (left by the move constructor) to
//     `dealloc`, because the destructor skips null buffers;
//   * leaves `*this` unchanged if building the new value throws.
mpz & mpz::operator=(mpz const & v) {
    mpz tmp(v);
    lean::swap(*this, tmp);
    return *this;
}
mpz & mpz::operator=(char const * v) {
    mpz tmp(v);
    lean::swap(*this, tmp);
    return *this;
}
mpz & mpz::operator=(unsigned int v) {
    mpz tmp(v);
    lean::swap(*this, tmp);
    return *this;
}
mpz & mpz::operator=(int v) {
    mpz tmp(v);
    lean::swap(*this, tmp);
    return *this;
}
// Three-way comparison of sign-magnitude values.
// Operands of opposite sign are ordered by the sign flag alone. For two
// negative operands the magnitude comparison is reversed, since the larger
// magnitude is the smaller number.
int cmp(mpz const & a, mpz const & b) {
    if (a.m_sign != b.m_sign)
        return a.m_sign ? -1 : 1;
    if (a.m_sign)
        return mpn_compare(b.m_digits, b.m_size, a.m_digits, a.m_size);
    return mpn_compare(a.m_digits, a.m_size, b.m_digits, b.m_size);
}
// Compares `a` with an unsigned machine integer, treated as a one-digit magnitude.
// Any value with the sign flag set compares as smaller.
int cmp(mpz const & a, unsigned b) {
    if (!a.m_sign)
        return mpn_compare(a.m_digits, a.m_size, &b, 1);
    return -1;
}
// ---------------------------------------------------------------------------
// Operations not yet implemented in the non-GMP backend.
// Every function below is a placeholder whose body is `lean_unreachable()`.
// Builds without LEAN_USE_GMP therefore cannot compare with `int`, do
// arithmetic, bit manipulation, exponentiation or gcd, or print a value.
// NOTE(review): the non-void placeholders have no `return` statement. That is
// only well-defined if `lean_unreachable()` never returns -- TODO confirm in
// runtime/debug.h.
// ---------------------------------------------------------------------------
int cmp(mpz const & a, int b) {
// TODO
lean_unreachable();
}
// Compound arithmetic (+=, -=, *=, /=) -- placeholders.
mpz & mpz::operator+=(mpz const & o) {
// TODO
lean_unreachable();
}
mpz & mpz::operator+=(unsigned u) {
// TODO
lean_unreachable();
}
mpz & mpz::operator+=(int u) {
// TODO
lean_unreachable();
}
mpz & mpz::operator-=(mpz const & o) {
// TODO
lean_unreachable();
}
mpz & mpz::operator-=(unsigned u) {
// TODO
lean_unreachable();
}
mpz & mpz::operator-=(int u) {
// TODO
lean_unreachable();
}
mpz & mpz::operator*=(mpz const & o) {
// TODO
lean_unreachable();
}
mpz & mpz::operator*=(unsigned u) {
// TODO
lean_unreachable();
}
mpz & mpz::operator*=(int u) {
// TODO
lean_unreachable();
}
mpz & mpz::operator/=(mpz const & o) {
// TODO
lean_unreachable();
}
mpz & mpz::operator/=(unsigned u) {
// TODO
lean_unreachable();
}
// Remainder, power and logarithm -- placeholders.
mpz rem(mpz const & a, mpz const & b) {
// TODO
lean_unreachable();
}
mpz mpz::pow(unsigned int exp) const {
// TODO
lean_unreachable();
}
size_t mpz::log2() const {
// TODO
lean_unreachable();
}
mpz operator%(mpz const & a, mpz const & b) {
// TODO
lean_unreachable();
}
// Bitwise operators -- placeholders.
mpz & mpz::operator&=(mpz const & o) {
// TODO
lean_unreachable();
}
mpz & mpz::operator|=(mpz const & o) {
// TODO
lean_unreachable();
}
mpz & mpz::operator^=(mpz const & o) {
// TODO
lean_unreachable();
}
// Power-of-two scaling, exponentiation and gcd helpers -- placeholders.
void mul2k(mpz & a, mpz const & b, unsigned k) {
// TODO
lean_unreachable();
}
void div2k(mpz & a, mpz const & b, unsigned k) {
// TODO
lean_unreachable();
}
void mod2k(mpz & a, mpz const & b, unsigned k) {
// TODO
lean_unreachable();
}
void power(mpz & a, mpz const & b, unsigned k) {
// TODO
lean_unreachable();
}
void gcd(mpz & g, mpz const & a, mpz const & b) {
// TODO
lean_unreachable();
}
// Stream output -- placeholder. mpz::to_string also depends on this.
std::ostream & operator<<(std::ostream & out, mpz const & v) {
// TODO
lean_unreachable();
}
#endif
// Returns the text that `operator<<` produces for this value.
std::string mpz::to_string() const {
    std::ostringstream text;
    text << *this;
    return text.str();
}
}
// Debugging helper: writes `n` and a newline to standard output, then flushes.
void print(lean::mpz const & n) {
    std::cout << n << '\n';
    std::cout.flush();
}